Some issues and optimization suggestions #139

Open
TensorPulse opened this issue Aug 22, 2024 · 26 comments

@TensorPulse

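Hello, and thank you for providing such a complete framework! While using it and porting baselines, I ran into some problems and inconveniences, which I list here for your reference.
Disclaimer: the following issues and suggestions reflect only my personal views and are for reference only.
Problem: when data_preparation is run directly in PyCharm, it reports that the dataset files cannot be found, and the same happens when running train. After the following change it runs: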
OUTPUT_DIR = "../../../experiments/datasets/" + DATASET_NAME
DATA_FILE_PATH = "../../../datasets/raw_data/{0}/{0}.npz".format(DATASET_NAME)
GRAPH_FILE_PATH = "../../../datasets/raw_data/{0}/adj_{0}".format(DATASET_NAME)
DISTANCE_FILE_PATH = "../../../datasets/raw_data/{0}/distance_{0}".format(DATASET_NAME)
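The original paths are relative to the repository root (the generation commands later in this thread are run from there), whereas PyCharm run configurations typically use the script's own folder as the working directory. The `../../../` prefix fixes PyCharm but breaks running from the root. A minimal sketch that resolves the same paths from the script's location instead, so either works. It assumes, as the `../../../` prefix implies, that the script sits three levels below the repository root (e.g. `scripts/data_preparation/<DATASET_NAME>/generate_training_data.py`); `resolve_paths` is a hypothetical helper, not part of BasicTS.

```python
from pathlib import Path


def resolve_paths(dataset_name: str, script_file: str) -> dict:
    """Build data-preparation paths anchored at the repository root.

    Assumption (hypothetical layout): script_file lives at
    <repo>/scripts/data_preparation/<DATASET_NAME>/<script>.py
    """
    # parents[0] = <DATASET_NAME>, [1] = data_preparation, [2] = scripts, [3] = repo root
    project_root = Path(script_file).resolve().parents[3]
    raw_dir = project_root / "datasets" / "raw_data" / dataset_name
    return {
        "OUTPUT_DIR": project_root / "experiments" / "datasets" / dataset_name,
        "DATA_FILE_PATH": raw_dir / f"{dataset_name}.npz",
        "GRAPH_FILE_PATH": raw_dir / f"adj_{dataset_name}",
        "DISTANCE_FILE_PATH": raw_dir / f"distance_{dataset_name}",
    }
```

Inside the data-preparation script one would call `resolve_paths(DATASET_NAME, __file__)`, so the result no longer depends on the working directory.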
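Optimization suggestions:
1. Normalization and denormalization of datasets: with CFG.RESCALE = True, the data is normalized as a whole and denormalized; with CFG.RESCALE = False, the data is not denormalized and each channel is normalized separately. This could be split into two variables: one controlling how the data is normalized, and one controlling whether it is denormalized.
2. Training results are not clearly labeled: a directory named with only the model name and the number of epochs does not convey enough information. It could be changed as follows: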
CFG.TRAIN.CKPT_SAVE_DIR = os.path.join(
"checkpoints",
CFG.MODEL.NAME,
"_".join([CFG.DATASET_NAME, str(CFG.TRAIN.NUM_EPOCHS)])
)
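3. The project has few visualization interfaces: more metrics could be logged to TensorBoard, or an interface for saving predictions could be added.
4. The project's cfg files have many hidden options. A Simple_CFG that lists every option explicitly could be added, for example: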
CFG.MODEL.SETUP_GRAPH = False
CFG.TRAIN.FINETUNE_FROM
CFG.RESCALE = True
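5. The STGODE baseline needs two tensors, A_sp_hat and A_se_hat. Even after the whole model is moved to the GPU, these two tensors stay on the CPU until a later .to(x.device) call. This is inconvenient when porting the model, because one has to track down where the tensors are finally used. I suggest using:
from easytorch.device import get_device_type
if get_device_type() == 'gpu':
    device = 'cuda'
else:
    device = 'cpu'
self.device = device
or: from easytorch.device import to_device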
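6. There is no progress bar during testing. It can be added like this:

from tqdm import tqdm

# tqdm progress bar
data_iter = tqdm(self.test_data_loader)

# test loop
for iter_index, data in enumerate(data_iter):
    ...  # existing test-step body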
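For item 5, a common PyTorch pattern (a swapped-in technique, not the easytorch approach proposed above) is to register fixed tensors such as A_sp_hat and A_se_hat as buffers: buffers move with the module on `.to()` and, by default, are saved in its `state_dict`, so no explicit device bookkeeping is needed. The module below is a hypothetical stand-in, not STGODE's actual architecture.

```python
import torch
from torch import nn


class GraphModuleSketch(nn.Module):
    """Hypothetical module holding two fixed adjacency matrices as buffers."""

    def __init__(self, a_sp_hat: torch.Tensor, a_se_hat: torch.Tensor):
        super().__init__()
        # Buffers follow model.to(device) but are not trainable parameters.
        self.register_buffer("A_sp_hat", a_sp_hat)
        self.register_buffer("A_se_hat", a_se_hat)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (batch, num_nodes, features); the buffers are already on x's device,
        # and (N, N) @ (B, N, F) broadcasts over the batch dimension.
        return self.A_sp_hat @ x + self.A_se_hat @ x


num_nodes = 4
model = GraphModuleSketch(torch.eye(num_nodes), torch.eye(num_nodes))
device = "cuda" if torch.cuda.is_available() else "cpu"
model = model.to(device)  # moves the buffers as well
out = model(torch.ones(2, num_nodes, 3, device=device))
print(model.A_sp_hat.device, out.shape)
```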

@zezhishao
Collaborator

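Thank you very much for your suggestions; they are very helpful to us!
We have long planned an overall upgrade of the code but have never found the time. Thanks again; we will address your suggestions one by one. If you have already made some of these changes, you are welcome to merge them into the main branch via a PR and become a BasicTS developer~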

@zezhishao
Collaborator

What exactly do you mean by "adding more metrics to TensorBoard"? Could you give some concrete requirements?

@TensorPulse
Author

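For example, the model's computation graph and architecture diagram, and histograms of how weights and biases change over time.
Visualizations of predicted, ground-truth, and historical values; projections of embeddings into a low-dimensional space; and so on.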
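For the last item, TensorBoard's embedding projector offers PCA among its projection methods. A minimal NumPy sketch of that PCA step, assuming the embeddings form a `(num_items, embed_dim)` array (for instance, learned node embeddings); the resulting 2-D coordinates can then be scatter-plotted or handed to a logger. `project_2d` is a hypothetical helper and the sizes are illustrative.

```python
import numpy as np


def project_2d(embeddings: np.ndarray) -> np.ndarray:
    """Project (num_items, dim) embeddings onto their top-2 principal components."""
    centered = embeddings - embeddings.mean(axis=0, keepdims=True)
    # Rows of vt are principal directions, ordered by decreasing singular value.
    _, _, vt = np.linalg.svd(centered, full_matrices=False)
    return centered @ vt[:2].T


rng = np.random.default_rng(0)
node_emb = rng.normal(size=(50, 32))  # e.g. 50 nodes with node_dim = 32 (illustrative)
coords = project_2d(node_emb)
print(coords.shape)
```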

@TensorPulse
Author

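Personally, I think the normalization of the ETT datasets in the BasicTS paper is somewhat problematic. Unlike the PEMS datasets, whose columns all share the same unit, the columns of the ETT datasets measure different quantities, so their magnitudes sometimes differ greatly. Normalizing each channel separately may be more appropriate than normalizing globally. If denormalization is needed when computing metrics, see optimization suggestion 1.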
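A small numeric illustration of this point, with made-up values: when one channel is orders of magnitude larger than another, a single global mean/std squashes the small channel to nearly constant values, whereas per-channel statistics give every channel unit variance.

```python
import numpy as np

# Two channels on very different scales (values are made up for illustration).
data = np.array([[0.0, 1000.0],
                 [2.0, 3000.0]])

# Global normalization: one mean/std over all values.
global_norm = (data - data.mean()) / data.std()

# Per-channel normalization: mean/std computed per column.
per_channel_norm = (data - data.mean(axis=0)) / data.std(axis=0)

print(global_norm.std(axis=0))       # channel 0 is nearly flat
print(per_channel_norm.std(axis=0))  # [1. 1.]
```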

@zezhishao
Collaborator

zezhishao commented Aug 27, 2024

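OK, thanks for the suggestion. I am developing a new version of BasicTS; it is nearly finished and will be released in the next few days.
After that, users will be able to specify the normalization method and whether to denormalize at training time, without preprocessing the data in advance.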
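A minimal sketch of suggestion 1 with the two concerns split into separate flags. The flag names `norm_each_channel` and `rescale` mirror the `SCALER` parameters in the config log later in this thread, but the class itself is a hypothetical NumPy illustration, not BasicTS's `ZScoreScaler`. It assumes data shaped `(num_timesteps, num_channels)` and statistics fitted on the training portion only.

```python
import numpy as np


class ZScoreScalerSketch:
    """Z-score scaler with independent normalization and denormalization switches."""

    def __init__(self, norm_each_channel: bool, rescale: bool):
        self.norm_each_channel = norm_each_channel  # per-channel vs. global statistics
        self.rescale = rescale  # whether outputs are mapped back before metrics
        self.mean = None
        self.std = None

    def fit(self, train_data: np.ndarray) -> "ZScoreScalerSketch":
        if self.norm_each_channel:
            self.mean = train_data.mean(axis=0, keepdims=True)
            self.std = train_data.std(axis=0, keepdims=True)
        else:
            self.mean = train_data.mean()
            self.std = train_data.std()
        # Guard against constant channels (zero std).
        self.std = np.where(self.std == 0, 1.0, self.std)
        return self

    def transform(self, data: np.ndarray) -> np.ndarray:
        return (data - self.mean) / self.std

    def inverse_transform(self, data: np.ndarray) -> np.ndarray:
        return data * self.std + self.mean

    def for_metrics(self, data: np.ndarray) -> np.ndarray:
        """Denormalize before computing metrics only if rescale is enabled."""
        return self.inverse_transform(data) if self.rescale else data
```

In this sketch, `norm_each_channel=True, rescale=False` normalizes each channel and computes metrics on the normalized scale.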

@TensorPulse
Author

Hello, does BasicTS support automatic hyperparameter tuning?
References: https://zhuanlan.zhihu.com/p/401190615?utm_id=0
https://github.com/LibCity/Bigscity-LibCity

@zezhishao
Collaborator

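Hello, the next version does not include automatic hyperparameter tuning. I am not very familiar with it, and in my experience these datasets don't seem all that sensitive to hyperparameters.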

@huiguhean

Hi, running train fails with FileNotFoundError: [Errno 2] No such file or directory: 'datasets/ETTh1/scaler_in_96_out_336_rescale_True.pkl'. Which folder do I need to modify?

@huiguhean

Fixed it by changing the dataset settings of the corresponding model under baselines.

@zezhishao
Collaborator

Hello, in the current version you still need to generate a dataset for each input/output length manually, with the following command:

python scripts/data_preparation/${DATASET_NAME}/generate_training_data.py --history_seq_len ${INPUT_LEN} --future_seq_len ${OUTPUT_LEN}

For example:

python scripts/data_preparation/ETTh1/generate_training_data.py --history_seq_len 96 --future_seq_len 336

A new version that lets you specify these lengths at training time is coming soon. Stay tuned~
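A small sketch that assembles the generation command for several output lengths, plus the scaler file name that training then expects. The script path and flags are the ones given above; the file-name pattern is inferred only from the `FileNotFoundError` message quoted earlier in this thread, and all three helpers are hypothetical.

```python
import subprocess
from typing import List


def generation_command(dataset_name: str, input_len: int, output_len: int) -> List[str]:
    """Command line for generating one input/output-length variant of a dataset."""
    return [
        "python", f"scripts/data_preparation/{dataset_name}/generate_training_data.py",
        "--history_seq_len", str(input_len),
        "--future_seq_len", str(output_len),
    ]


def expected_scaler_file(dataset_name: str, input_len: int, output_len: int, rescale: bool) -> str:
    # Pattern inferred from: datasets/ETTh1/scaler_in_96_out_336_rescale_True.pkl
    return f"datasets/{dataset_name}/scaler_in_{input_len}_out_{output_len}_rescale_{rescale}.pkl"


def generate_all(dataset_name: str, input_len: int, output_lens: List[int]) -> None:
    # Run from the repository root; check=True stops at the first failing script.
    for output_len in output_lens:
        subprocess.run(generation_command(dataset_name, input_len, output_len), check=True)
```

For example, `generate_all("ETTh1", 96, [336])` runs the command shown above.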

@huiguhean

Awesome, nothing more to say!

@TensorPulse
Author

Hello, will the next version provide an interface for univariate forecasting?

@zezhishao
Collaborator

What do you mean by univariate? Do you mean a dataset with only a single time series?

@TensorPulse
Author

Sorry, it seems that only the runner needs to be redefined. By univariate forecasting I mean forecasting the OT variable of datasets such as ETT.
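A sketch of the data side of univariate (OT-only) forecasting, under two assumptions not confirmed in this thread: the data array is laid out as `(num_timesteps, num_variables, num_features)`, and OT is the last variable (it is the last column of the ETT CSV files). Keeping a size-1 variable axis lets the rest of the pipeline treat it as a one-node dataset. `select_variable` is a hypothetical helper.

```python
import numpy as np


def select_variable(data: np.ndarray, var_index: int = -1) -> np.ndarray:
    """Keep one variable, preserving a size-1 variable axis: (T, N, C) -> (T, 1, C)."""
    return data[:, [var_index], :]


T, N, C = 100, 7, 3  # e.g. 7 ETT variables; 3 features as in FORWARD_FEATURES [0, 1, 2]
data = np.arange(T * N * C, dtype=float).reshape(T, N, C)
ot_only = select_variable(data)  # OT assumed to be the last variable
print(ot_only.shape)  # (100, 1, 3)
```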

@TensorPulse
Author

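Hello, I would like to ask about the history and prediction lengths behind the long-term forecasting results in the BasicTS paper. The history lengths in the code seem to differ across models, while the prediction length should be 336 for all of them, so I would like to confirm.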

@zezhishao
Collaborator

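Yes. In fact, different papers use different history lengths, and the optimal history length differs between methods.
We use whichever of several commonly used history lengths performs best.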
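The selection procedure described above, as a sketch: evaluate each candidate history length and keep the best one. `evaluate` is a hypothetical callable (e.g. train with that history length and return an error); the thread does not say which split or metric drives the choice, so "lower is better on validation" is an assumption here, and the candidate lengths in the test are illustrative.

```python
from typing import Callable, Dict, Iterable, Tuple


def select_history_len(
    candidates: Iterable[int], evaluate: Callable[[int], float]
) -> Tuple[int, Dict[int, float]]:
    """Return the history length with the lowest error, plus every score."""
    # Assumption: evaluate() returns a validation error where lower is better.
    scores = {length: evaluate(length) for length in candidates}
    best = min(scores, key=scores.get)
    return best, scores
```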

@morestart

I noticed that easytorch is no longer maintained. Are you considering switching the backend in the new version?

@zezhishao
Collaborator

Currently, EasyTorch is still able to meet the needs of BasicTS, so there won't be any changes in the short term. However, in the longer term, I hope that the backend of BasicTS will no longer need to rely on other packages, although this will be time-consuming.

Do you have any other needs that the current EasyTorch backend cannot satisfy?

@zezhishao
Collaborator

Hello, everyone! The BasicTS code has been updated. Feel free to check it out and use it!

@TensorPulse
Author

Hello, are the graph structures of the PEMS datasets directed? The code seems to build an undirected graph by default. Could you provide an option to choose between a directed and an undirected graph? Or is the dataset itself an undirected graph?
Reference code:
i, j, distance = int(row[0]), int(row[1]), float(row[2])
adjacency_matrix_connectivity[id_dict[i], id_dict[j]] = 1
adjacency_matrix_distance[id_dict[i], id_dict[j]] = distance
if not directed:
    adjacency_matrix_connectivity[id_dict[j], id_dict[i]] = 1
    adjacency_matrix_distance[id_dict[j], id_dict[i]] = distance
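A self-contained version of the loop above with `directed` exposed as a parameter, as a sketch. It assumes distance rows of the form `from, to, cost` (the layout commonly used by PEMS0X distance CSVs, which usually start with a header row) and an `id_dict` mapping raw sensor IDs to matrix indices; `build_adjacency` and the sample data are hypothetical.

```python
import csv
import io
from typing import Dict, Iterable, Sequence, Tuple

import numpy as np


def build_adjacency(
    rows: Iterable[Sequence[str]], id_dict: Dict[int, int], directed: bool
) -> Tuple[np.ndarray, np.ndarray]:
    """Build 0/1 connectivity and distance matrices from (from, to, cost) rows."""
    n = len(id_dict)
    connectivity = np.zeros((n, n), dtype=np.float32)
    distance_matrix = np.zeros((n, n), dtype=np.float32)
    for row in rows:
        i, j, distance = int(row[0]), int(row[1]), float(row[2])
        connectivity[id_dict[i], id_dict[j]] = 1
        distance_matrix[id_dict[i], id_dict[j]] = distance
        if not directed:
            # Mirror each edge so both matrices are symmetric.
            connectivity[id_dict[j], id_dict[i]] = 1
            distance_matrix[id_dict[j], id_dict[i]] = distance
    return connectivity, distance_matrix


csv_text = "from,to,cost\n10,20,1.5\n20,30,2.0\n"  # hypothetical sample file
reader = csv.reader(io.StringIO(csv_text))
next(reader)  # skip the header row
rows = list(reader)
id_dict = {10: 0, 20: 1, 30: 2}
conn_d, _ = build_adjacency(rows, id_dict, directed=True)
conn_u, dist_u = build_adjacency(rows, id_dict, directed=False)
```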

@zezhishao
Collaborator

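Hello, the processing scripts for the PEMS0X datasets come with the datasets themselves. If you search GitHub for A[id_dict[i], id_dict[j]] = 1, you will find implementations in many repositories, such as STGCN, ASTGCN, STSGCN, and STFGCN.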

@TensorPulse
Author

Hello, could you update the data visualization code?

@zezhishao
Collaborator

OK, I forgot to update it. I'll do it tomorrow.

@TensorPulse
Author

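The overlap parameter in CFG.DATASET.PARAM is not listed in the complete_config file.
Is the split for overlapping datasets reasonable? Taking PEMS08 as an example, with equal ratios for the validation and test sets, the validation set has 11 more samples than the test set.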

@zezhishao
Collaborator

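Hello, thanks for the suggestion. The overlap parameter now defaults to False and is adjusted automatically, with a warning.
When the raw data for the Train/Valid/Test split is too short to form enough samples (e.g., the Illness dataset), overlap is enabled automatically and a message is logged. For example, running STID on the Illness dataset produces the following log: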

2024-09-13 10:27:07,383 - easytorch-launcher - INFO - Launching EasyTorch training.
DESCRIPTION: An Example Config
GPU_NUM: 1
RUNNER: <class 'basicts.runners.runner_zoo.simple_tsf_runner.SimpleTimeSeriesForecastingRunner'>
DATASET:
  NAME: Illness
  TYPE: <class 'basicts.data.simple_tsf_dataset.TimeSeriesForecastingDataset'>
  PARAM:
    dataset_name: Illness
    train_val_test_ratio: [0.7, 0.1, 0.2]
    input_len: 96
    output_len: 48
SCALER:
  TYPE: <class 'basicts.scaler.z_score_scaler.ZScoreScaler'>
  PARAM:
    dataset_name: Illness
    train_ratio: 0.7
    norm_each_channel: True
    rescale: False
MODEL:
  NAME: STID
  ARCH: <class 'baselines.STID.arch.stid_arch.STID'>
  PARAM:
    num_nodes: 7
    input_len: 96
    input_dim: 1
    embed_dim: 2048
    output_len: 48
    num_layer: 1
    if_node: True
    node_dim: 32
    if_T_i_D: True
    if_D_i_W: True
    temp_dim_tid: 8
    temp_dim_diw: 8
    time_of_day_size: 1
    day_of_week_size: 7
  FORWARD_FEATURES: [0, 1, 2]
  TARGET_FEATURES: [0]
METRICS:
  FUNCS:
    MAE: masked_mae
    MSE: masked_mse
  TARGET: MAE
  NULL_VAL: nan
TRAIN:
  NUM_EPOCHS: 100
  CKPT_SAVE_DIR: checkpoints/STID/Illness_100_96_48
  LOSS: masked_mae
  OPTIM:
    TYPE: Adam
    PARAM:
      lr: 0.0005
      weight_decay: 0.0005
  LR_SCHEDULER:
    TYPE: MultiStepLR
    PARAM:
      milestones: [1, 3, 5]
      gamma: 0.1
  CLIP_GRAD_PARAM:
    max_norm: 5.0
  DATA:
    BATCH_SIZE: 64
    SHUFFLE: True
VAL:
  INTERVAL: 1
  DATA:
    BATCH_SIZE: 64
TEST:
  INTERVAL: 1
  DATA:
    BATCH_SIZE: 64
EVAL:
  USE_GPU: True

2024-09-13 10:27:07,451 - easytorch-env - INFO - Use devices 0.
2024-09-13 10:27:07,506 - easytorch-launcher - INFO - Initializing runner "<class 'basicts.runners.runner_zoo.simple_tsf_runner.SimpleTimeSeriesForecastingRunner'>"
2024-09-13 10:27:07,506 - easytorch-env - INFO - Disable TF32 mode
2024-09-13 10:27:07,506 - easytorch - INFO - Set ckpt save dir: 'checkpoints/STID/Illness_100_96_48/9cd15181d2d202a278536bfd1f1031a0'
2024-09-13 10:27:07,506 - easytorch - INFO - Building model.
2024-09-13 10:27:07,747 - easytorch-training - INFO - Initializing training.
2024-09-13 10:27:07,747 - easytorch-training - INFO - Set clip grad, param: {'max_norm': 5.0}
2024-09-13 10:27:07,748 - easytorch-training - INFO - Building training data loader.
2024-09-13 10:27:07,748 - easytorch-training - INFO - Train dataset length: 534
2024-09-13 10:27:08,271 - easytorch-training - INFO - Set optim: Adam (
Parameter Group 0
    amsgrad: False
    betas: (0.9, 0.999)
    capturable: False
    differentiable: False
    eps: 1e-08
    foreach: None
    fused: None
    lr: 0.0005
    maximize: False
    weight_decay: 0.0005
)
2024-09-13 10:27:08,271 - easytorch-training - INFO - Set lr_scheduler: <torch.optim.lr_scheduler.MultiStepLR object at 0x7faaa1792550>
2024-09-13 10:27:08,271 - easytorch-training - INFO - Loading Checkpoint from 'checkpoints/STID/Illness_100_96_48/9cd15181d2d202a278536bfd1f1031a0/STID_100.pt'
2024-09-13 10:27:08,343 - easytorch-training - INFO - Resume training
2024-09-13 10:27:08,344 - easytorch-training - INFO - Initializing validation.
2024-09-13 10:27:08,345 - easytorch-training - INFO - Building val data loader.
2024-09-13 10:27:08,345 - easytorch-training - INFO - Validation dataset is too short, enabling overlap. See details in /home/S22/workspace/BasicTS/basicts/data/simple_tsf_dataset.py at line 96.
2024-09-13 10:27:08,345 - easytorch-training - INFO - Validation dataset length: 96
2024-09-13 10:27:08,383 - easytorch-training - INFO - Test dataset length: 50
2024-09-13 10:27:08,384 - easytorch-training - INFO - Number of parameters: 9090224
2024-09-13 10:27:08,384 - easytorch-training - INFO - The training finished at 2024-09-13 10:27:08
2024-09-13 10:27:08,384 - easytorch-training - INFO - Evaluating the best model on the test set.
2024-09-13 10:27:08,384 - easytorch-training - INFO - Loading Checkpoint from 'checkpoints/STID/Illness_100_96_48/9cd15181d2d202a278536bfd1f1031a0/STID_best_val_MAE.pt'
100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00,  7.40it/s]
2024-09-13 10:27:08,657 - easytorch-training - INFO - Result <test>: [test_time: 0.21 (s), test_MAE: 1.3040, test_MSE: 3.2762]
2024-09-13 10:27:08,658 - easytorch-training - INFO - Test results saved to checkpoints/STID/Illness_100_96_48/9cd15181d2d202a278536bfd1f1031a0/test_results.npz.
2024-09-13 10:27:08,659 - easytorch-training - INFO - Test metrics saved to checkpoints/STID/Illness_100_96_48/9cd15181d2d202a278536bfd1f1031a0/test_metrics.json.

The line "Validation dataset is too short, enabling overlap. See details in /home/S22/workspace/BasicTS/basicts/data/simple_tsf_dataset.py at line 96." means that the validation split is too short here, so overlap is set to True for the Validation dataset.
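A sketch of how sample counts can be computed with and without overlap. The segment offsets below (borrowing `input_len - 1` steps of history from the preceding split and `output_len` steps of targets from the following split) are an assumption about one possible scheme, not a transcription of `simple_tsf_dataset.py`. With an assumed total length of 966 weekly steps for Illness (not stated in this thread), it yields the train/validation/test lengths of 534, 96 and 50 reported in the log above.

```python
from typing import Dict, Sequence


def num_windows(segment_len: int, input_len: int, output_len: int) -> int:
    """Number of stride-1 sliding windows of length input_len + output_len."""
    return max(segment_len - input_len - output_len + 1, 0)


def split_sample_counts(
    total_len: int, ratios: Sequence[float], input_len: int, output_len: int
) -> Dict[str, int]:
    valid_len = int(total_len * ratios[1])
    test_len = int(total_len * ratios[2])
    train_len = total_len - valid_len - test_len  # train takes the remainder
    counts = {}
    for name, seg_len in (("train", train_len), ("valid", valid_len), ("test", test_len)):
        if seg_len < input_len + output_len:  # too short for even one window
            # Assumed overlap scheme: borrow history from the previous split
            # and targets from the next one (where those splits exist).
            left = input_len - 1 if name != "train" else 0
            right = output_len if name != "test" else 0
            seg_len += left + right
        counts[name] = num_windows(seg_len, input_len, output_len)
    return counts


print(split_sample_counts(966, [0.7, 0.1, 0.2], input_len=96, output_len=48))
# {'train': 534, 'valid': 96, 'test': 50}
```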

@zezhishao
Collaborator

> Hello, could you update the data visualization code?

Updated.
