Description
Is your feature request related to a problem?
Generating batches can be slow. A typical dataset used in Xbatcher originates in either cloud storage or a local file system, is loaded into Xarray lazily using Dask, and includes some lazy processing. Xbatcher then does some indexing/transposing/concatenation on the underlying arrays and loads the data into what we think of as a batch. When batching through a dataset multiple times, it is often desirable to cache batches to speed up loading them later on (e.g. while training an ML model).
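For context, a minimal sketch of a typical setup (the store path and variable name below are hypothetical):

```python
import xarray as xr
from xbatcher import BatchGenerator

# lazily open a dask-backed dataset from cloud storage (hypothetical path/variable)
ds = xr.open_zarr('s3://foo/source.zarr')
da = ds['air_temperature']

# some lazy processing, e.g. a simple normalization
da = (da - da.mean()) / da.std()

# each iteration indexes/transposes/loads one batch into memory, which is the slow part
bgen = BatchGenerator(da, {'time': 10})
for batch in bgen:
    ...
```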
It would be nice if Xbatcher included some features that made the process of generating and caching batches less painful.
Describe the solution you'd like
After discussing this with @maxrjones and @norlandrhagen, I think there are two leading approaches to caching that Xbatcher should explore.
- Explicitly dump a generator to a cache
This option is similar to what @leifdenby has proposed in #40. This would require users to explicitly pre-process their generator into something optimized for batch generation. In practice, it would look something like this:
```python
# create a batch generator
bgen = BatchGenerator(da, {'time': 10})

# dump to a zarr store
bgen.to_zarr('s3://foo/bar.zarr')

# load from a zarr store
bgen = BatchGenerator.from_zarr('s3://foo/bar.zarr')

for batch in bgen:
    # use batch ...
    ...
```
I like this approach because it produces a zarr store that could be easily interrogated by an interested user. The main downside is that it requires an explicit cache dump step.
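For illustration only, a rough sketch of how the dump/load steps could work outside of Xbatcher, assuming each batch is written to its own group of a single zarr store (the helper names and group layout are hypothetical, not a proposed API):

```python
import xarray as xr

def dump_batches_to_zarr(bgen, store):
    # hypothetical helper: write every batch to its own group in one zarr store
    n_batches = 0
    for i, batch in enumerate(bgen):
        # batches may be DataArrays; promote to a Dataset so they can be serialized
        ds = batch.to_dataset(name='batch') if isinstance(batch, xr.DataArray) else batch
        ds.to_zarr(store, group=f'batch_{i}', mode='a')
        n_batches += 1
    return n_batches

def load_batches_from_zarr(store, n_batches):
    # hypothetical helper: lazily reopen each cached batch
    for i in range(n_batches):
        yield xr.open_zarr(store, group=f'batch_{i}')
```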
- Attach a configurable cache to the BatchGenerator class
This option would push cache management inside the generator itself, such that generating a batch would first check whether it already exists in a configurable cache. The top-level API could look something like this:
```python
# user creates a cache that follows a mutable mapping interface
cache = zict.Buffer({}, zarr.storage.FSStore('s3://foo'))

bgen = BatchGenerator(da, {'time': 10}, cache=cache)

for batch in bgen:
    # use batch ... (caching happens behind the scenes on a batch-by-batch basis)
    ...

for batch in bgen:
    # this will be much faster since all batches are cached
    ...

# get a single batch from the cache
bgen.cache[12]
```
I like this approach because it has the potential to be highly configurable (when needed) and does not require a manual cache dump. The main downside I see is that the cache will be split into a bunch of small datasets (probably zarr stores).
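For illustration, the cache lookup inside the generator could follow a simple check-then-populate pattern; the function and the encode/decode hooks below are hypothetical and only sketch the control flow:

```python
def get_batch(bgen, cache, idx, encode, decode):
    # hypothetical sketch: consult the cache before generating a batch
    key = str(idx)
    if cache is not None and key in cache:
        return decode(cache[key])   # cache hit: reconstruct the stored batch
    batch = bgen[idx]               # cache miss: fall back to normal (slow) generation
    if cache is not None:
        cache[key] = encode(batch)  # populate the cache for later passes
    return batch
```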
Describe alternatives you've considered
We should consider not supporting caching in Xbatcher and instead developing a few recipes for how to use Dask caching (e.g. https://docs.dask.org/en/stable/caching.html or https://github.com/radix-ai/graphchain) or caching at the data loader level (e.g. https://www.tensorflow.org/api_docs/python/tf/data/Dataset#cache).
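As a concrete example of the recipe route, Dask's opportunistic cache can be enabled globally without any changes to Xbatcher (the 2 GB size below is just an illustrative choice):

```python
from dask.cache import Cache

# enable Dask's opportunistic (cachey-backed) cache for subsequent computations
cache = Cache(2e9)  # keep up to ~2 GB of intermediate results
cache.register()

# repeated passes through a dask-backed BatchGenerator can now reuse
# cached intermediate results instead of recomputing them from scratch
```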
See also this interesting conversation where this feature is discussed in the Pytorch context: pytorch/pytorch#35642
Additional context
No response