Improves MinariDataset load speed to speed up sampling, fixes total_steps bug and adds test coverage #129
…_steps tests coverage and bugfix
Not change requests, just comments
file.attrs.modify("total_episodes", last_episode_id + len(buffer))
file.attrs.modify(
    "total_steps", file.attrs["total_steps"] + additional_steps
)
self._total_steps = file.attrs["total_steps"] + additional_steps
self._total_episodes = last_episode_id + len(buffer)

file.attrs.modify("total_episodes", self._total_episodes)
file.attrs.modify("total_steps", self._total_steps)
Thanks for these fixes as well. I am wondering whether we need self._total_steps, or whether we can just make the total_steps @property read from the file (and the same for total_episodes).
I think it might marginally save on file reads if we keep it the way it is for some use-cases, but it doesn't seem like it would make a super big difference either way.
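To make the trade-off in this thread concrete, here is a minimal sketch of the two options. `CountingAttrs` is a hypothetical stand-in for the file's attribute store (it is not part of Minari or h5py) that only counts reads, and both `CachedCounts` and `ReadThroughCounts` are illustrative names, not classes from this PR.

```python
class CountingAttrs:
    """Hypothetical stand-in for file attributes; counts every read."""

    def __init__(self, **values):
        self._values = dict(values)
        self.reads = 0

    def __getitem__(self, key):
        self.reads += 1
        return self._values[key]

    def modify(self, key, value):
        self._values[key] = value


class CachedCounts:
    """Keeps total_steps in memory, in the spirit of the diff above."""

    def __init__(self, attrs):
        self._attrs = attrs
        self._total_steps = attrs["total_steps"]  # single read at construction

    @property
    def total_steps(self):
        return self._total_steps

    def add_steps(self, additional_steps):
        # Update the cached value, then write it through to the file.
        self._total_steps += additional_steps
        self._attrs.modify("total_steps", self._total_steps)


class ReadThroughCounts:
    """Reads total_steps from the file on every access."""

    def __init__(self, attrs):
        self._attrs = attrs

    @property
    def total_steps(self):
        return self._attrs["total_steps"]

    def add_steps(self, additional_steps):
        current = self._attrs["total_steps"]
        self._attrs.modify("total_steps", current + additional_steps)
```

With the cached variant, repeated accesses to `total_steps` cost no further reads after construction; with the read-through variant, every access is one attribute read. Whether that difference matters in practice depends on how often the property is queried.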
total_steps = sum(
    self._data.apply(
        lambda episode: episode["total_timesteps"],
        episode_indices=episode_indices,
    )
)
This also means that creating subsets of the dataset can be very slow.
Previously, I used lazy initialization for the total_steps attribute; we removed it so that total_steps could be included in the MinariDatasetSpec.
We can maybe consider having metadata you can load without having to load the corresponding episode in a future PR, though I'm not super opinionated either way at this point.
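For reference, the lazy initialization mentioned in this thread could look roughly like the sketch below. It assumes per-episode lengths are available as a plain mapping; `EpisodeSubset`, `episode_lengths`, and `sum_calls` are hypothetical names used only for illustration and do not come from Minari.

```python
from functools import cached_property


class EpisodeSubset:
    """Hypothetical subset view; not the actual MinariDataset class."""

    def __init__(self, episode_lengths, episode_indices):
        # episode_lengths maps episode id -> number of timesteps.
        self._episode_lengths = episode_lengths
        self.episode_indices = list(episode_indices)
        self.sum_calls = 0  # instrumentation for this example only

    @cached_property
    def total_steps(self):
        # Computed on first access, then stored on the instance.
        self.sum_calls += 1
        return sum(self._episode_lengths[i] for i in self.episode_indices)
```

Creating the subset is then cheap, and the cost of summing episode lengths is paid only if `total_steps` is actually requested. The downside raised above still applies: a value computed lazily is not available up front for something like a spec object.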
Description
Previously, we would always iterate over all episodes to calculate total_steps when loading a MinariDataset. Now we will just load this from the MinariStorage when the MinariDataset is created, assuming no filter is specified. Empirically, this speeds up our sampling a lot because we no longer have to iterate over every episode regardless of how many we are sampling.
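The loading rule described above can be sketched as follows. This is an illustration under assumptions, not the code in this PR: `resolve_total_steps` and its parameters are hypothetical, and per-episode lengths are passed in as a plain sequence rather than read from storage.

```python
def resolve_total_steps(stored_total_steps, episode_lengths, episode_indices=None):
    """Return the step count for a dataset or for a filtered subset of it.

    stored_total_steps: the value already recorded in storage metadata.
    episode_lengths: sequence of per-episode timestep counts.
    episode_indices: optional subset of episode indices (the "filter").
    """
    if episode_indices is None:
        # No filter: trust the stored value and skip iterating over episodes.
        return stored_total_steps
    # Filter given: only the selected episodes contribute.
    return sum(episode_lengths[i] for i in episode_indices)
```

For example, `resolve_total_steps(15, [3, 5, 7])` returns the stored value without touching the episode lengths, while `resolve_total_steps(15, [3, 5, 7], [0, 2])` sums only the two selected episodes.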
Since the patch required changing how total_steps is loaded, I also extended test coverage for total_steps and found some bugs in how it is calculated in some settings, which this PR also patches.
Performance of Minari vs. Pickle for sampling 10 episodes across different dataset sizes.
Performance prior to patch:
Type of change
Checklist:
pre-commit checks with pre-commit run --all-files (see CONTRIBUTING.md instructions to set it up)
pytest -v and no errors are present.
pytest -v has generated that are related to my code to the best of my knowledge.