Improves MinariDataset load speed to speed up sampling, fixes total_steps bug and adds test coverage #129
Changes from all commits
15336ff
0d07d90
a7d3f5f
4c0f466
2c9c0ab
72d379d
a5d9e7b
@@ -205,13 +205,13 @@ def update_from_collector_env(
                     "id", last_episode_id + id
                 )
 
+            self._total_steps = file.attrs["total_steps"] + new_data_total_steps
+
             # Update metadata of minari dataset
             file.attrs.modify(
                 "total_episodes", last_episode_id + new_data_total_episodes
             )
-            file.attrs.modify(
-                "total_steps", file.attrs["total_steps"] + new_data_total_steps
-            )
+            file.attrs.modify("total_steps", self._total_steps)
             self._total_episodes = int(file.attrs["total_episodes"].item())
 
     def update_from_buffer(self, buffer: List[dict], data_path: str):
@@ -247,10 +247,11 @@ def update_from_buffer(self, buffer: List[dict], data_path: str):
 
             # TODO: save EpisodeMetadataCallback callback in MinariDataset and update new episode group metadata
 
-            file.attrs.modify("total_episodes", last_episode_id + len(buffer))
-            file.attrs.modify(
-                "total_steps", file.attrs["total_steps"] + additional_steps
-            )
+            self._total_steps = file.attrs["total_steps"] + additional_steps
+            self._total_episodes = last_episode_id + len(buffer)
+
+            file.attrs.modify("total_episodes", self._total_episodes)
+            file.attrs.modify("total_steps", self._total_steps)
Comment on lines -250 to +254

Thanks for these fixes also. I am wondering if we need

I think it might marginally save on file reads if we keep it the way it is for some use cases, but it doesn't seem like it would make a big difference either way.
             self._total_episodes = int(file.attrs["total_episodes"].item())
This also means that creating subsets of the dataset can be very slow.
Previously, I did a lazy initialization of the total_steps attribute; we removed it so that total_steps could be included in the MinariDatasetSpec.
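For reference, the lazy initialization mentioned above could look roughly like the sketch below. This is not the code that was removed from Minari: `LazyStepCount`, its `loads` counter, and the loader callable are hypothetical names used only to illustrate the pattern of computing `total_steps` on first access and caching it afterwards.

```python
from typing import Callable, List, Optional


class LazyStepCount:
    """Hypothetical stand-in for a dataset that defers counting its steps."""

    def __init__(self, episode_lengths_loader: Callable[[], List[int]]):
        # Assumption: the loader is the expensive part (e.g. reading a file).
        self._load_lengths = episode_lengths_loader
        self._total_steps: Optional[int] = None
        self.loads = 0  # how many times the expensive loader actually ran

    @property
    def total_steps(self) -> int:
        # Compute on first access only, then serve the cached value.
        if self._total_steps is None:
            self.loads += 1
            self._total_steps = sum(self._load_lengths())
        return self._total_steps


dataset = LazyStepCount(lambda: [3, 5, 2])
loads_before_access = dataset.loads
first_read = dataset.total_steps
second_read = dataset.total_steps
```

The trade-off is the one raised in this thread: a lazy value costs nothing at construction time, but it is not available up front for something like a spec object that needs it eagerly.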
We can maybe consider having metadata you can load without having to load the corresponding episode in a future PR, though I'm not super opinionated either way at this point.
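A rough sketch of what that could mean, under stated assumptions: `EpisodeHandle`, `make_handle`, and the `reads` counter are hypothetical and not part of Minari. The idea is only that small per-episode metadata stays in memory while the episode payload sits behind a callable that runs on demand.

```python
from dataclasses import dataclass
from typing import Callable, Dict, List

# Counts how often an episode payload is actually loaded (illustration only).
reads = {"payload": 0}


@dataclass
class EpisodeHandle:
    """Hypothetical handle: cheap metadata plus a deferred payload loader."""

    episode_id: int
    metadata: Dict[str, int]
    load_payload: Callable[[], List[float]]


def make_handle(episode_id: int, total_steps: int) -> EpisodeHandle:
    def load_payload() -> List[float]:
        # Stand-in for an expensive read of observations, actions, etc.
        reads["payload"] += 1
        return [0.0] * total_steps

    return EpisodeHandle(episode_id, {"total_steps": total_steps}, load_payload)


handles = [make_handle(i, n) for i, n in enumerate([3, 5, 2])]

# Filtering and aggregating only touch metadata, so no payload is loaded.
long_ids = [h.episode_id for h in handles if h.metadata["total_steps"] >= 3]
total_steps = sum(h.metadata["total_steps"] for h in handles)
```

With a layout like this, building a subset by filtering on metadata would not need to read any episode data until an episode is actually sampled.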