diff --git a/data_provider/data_factory.py b/data_provider/data_factory.py index 7af7698a..7fc458f6 100644 --- a/data_provider/data_factory.py +++ b/data_provider/data_factory.py @@ -23,7 +23,7 @@ def data_provider(args, flag): Data = data_dict[args.data] timeenc = 0 if args.embed != 'timeF' else 1 - shuffle_flag = False if flag == 'test' else True + shuffle_flag = False if (flag == 'test' or flag == 'TEST') else True drop_last = False batch_size = args.batch_size freq = args.freq