Skip to content

Commit

Permalink
Fix TSDataSampler Slicing Bug #1716 (#1803)
Browse files Browse the repository at this point in the history
* Fix TSDataSampler Slicing Bug #1716

* Fix TSDataSampler Slicing Bug #1716

* Fix TSDataSampler Slicing Bug #1716

* Fix TSDataSampler Slicing Bug with simpler implementation #1716
 with Simplified Implementation

* Refactor: Fix CI errors by addressing pylint formatting issues

* Refactor: Remove extraneous whitespace for improved code formatting with Black
  • Loading branch information
YeewahChan authored Jun 21, 2024
1 parent 3a348ae commit ebc0ca8
Show file tree
Hide file tree
Showing 2 changed files with 51 additions and 2 deletions.
2 changes: 1 addition & 1 deletion qlib/data/dataset/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -403,7 +403,7 @@ def __init__(
np.full((1, self.data_arr.shape[1]), np.nan, dtype=self.data_arr.dtype),
axis=0,
)
self.nan_idx = -1 # The last line is all NaN
self.nan_idx = len(self.data_arr) - 1 # The last line is all NaN; setting it to -1 can cause bug #1716

# the data type will be changed
# The index of usable data is between start_idx and end_idx
Expand Down
51 changes: 50 additions & 1 deletion tests/data_mid_layer_tests/test_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,9 @@
import pytest
import sys
from qlib.tests import TestAutoData
from qlib.data.dataset import TSDatasetH
from qlib.data.dataset import TSDatasetH, TSDataSampler
import numpy as np
import pandas as pd
import time
from qlib.data.dataset.handler import DataHandlerLP

Expand Down Expand Up @@ -98,6 +99,54 @@ def testTSDataset(self):
print(idx[i])


class TestTSDataSampler(unittest.TestCase):
    """Regression tests for the windows returned by ``TSDataSampler`` (issue #1716)."""

    # Month-end timestamps and instrument codes shared by every test in this class.
    DATETIME_LIST = ["2000-01-31", "2000-02-29", "2000-03-31", "2000-04-30", "2000-05-31"]
    INSTRUMENTS = ["000001", "000002", "000003", "000004", "000005"]

    @classmethod
    def _make_frame(cls):
        """Return a single-column frame with one row per (datetime, instrument) pair."""
        index = pd.MultiIndex.from_product(
            [pd.to_datetime(cls.DATETIME_LIST), cls.INSTRUMENTS], names=["datetime", "instrument"]
        )
        # A private, seeded generator keeps failures reproducible without
        # touching NumPy's global random state.
        values = np.random.RandomState(0).randn(len(index))
        return pd.DataFrame(data=values, index=index, columns=["factor"])

    def test_TSDataSampler(self):
        """
        Test TSDataSampler for issue #1716

        Sampling starts at the first available date with ``step_len=2``, so the
        first window is expected to begin with a NaN entry, and each following
        window is expected to start with the value the previous window ended on.
        """
        dataset = TSDataSampler(self._make_frame(), self.DATETIME_LIST[0], self.DATETIME_LIST[-1], step_len=2)

        # unittest assertions (rather than a bare ``assert``) survive ``python -O``
        # and report both operands on failure.
        self.assertEqual(len(dataset[0]), 2)
        self.assertTrue(np.isnan(dataset[0][0]))
        self.assertEqual(dataset[0][1], dataset[1][0])
        self.assertEqual(dataset[1][1], dataset[2][0])
        self.assertEqual(dataset[2][1], dataset[3][0])

    def test_TSDataSampler2(self):
        """
        Extra test TSDataSampler to prevent incorrect filling of nan for the values at the front

        Sampling starts at the third date with ``step_len=3``; the first two
        windows are expected to contain no NaN and to overlap by two entries.
        """
        step_len = 3
        dataset = TSDataSampler(
            self._make_frame(), self.DATETIME_LIST[2], self.DATETIME_LIST[-1], step_len=step_len
        )

        for i in range(step_len):
            self.assertFalse(np.isnan(dataset[0][i]))
            self.assertFalse(np.isnan(dataset[1][i]))
        self.assertEqual(dataset[0][1], dataset[1][0])
        self.assertEqual(dataset[0][2], dataset[1][1])


if __name__ == "__main__":
    # When executed as a script, run the test cases with verbose per-test reporting.
    unittest.main(verbosity=10)

Expand Down

0 comments on commit ebc0ca8

Please sign in to comment.