Cache batches #109

Open
@jhamman

Description

Is your feature request related to a problem?

Generating batches can be slow. A typical dataset used in Xbatcher originates in either cloud storage or a local file system, is loaded into Xarray lazily using Dask, and includes some lazy processing. Xbatcher then does some indexing/transposing/concatenation on the underlying arrays and loads the data into what we think of as a batch. When batching through a dataset multiple times, it is often desirable to cache batches so that loading them later on is faster (e.g. while training an ML model).
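
For concreteness, that workflow might look roughly like the sketch below; the store path, the 'air' variable, and the train_step function are placeholders, not real data or API:

import xarray as xr
from xbatcher import BatchGenerator

# open the dataset lazily; Zarr + Dask keep the data out of memory
ds = xr.open_zarr('s3://foo/bar.zarr')    # hypothetical store
da = ds['air'] - ds['air'].mean('time')   # some lazy preprocessing

bgen = BatchGenerator(da, {'time': 10})
for batch in bgen:
    # each iteration indexes/transposes/loads the underlying arrays,
    # which is the slow step that repeated passes pay for again and again
    train_step(batch)                      # placeholder training function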

It would be nice if Xbatcher included some features that made the process of generating and caching batches less painful.

Describe the solution you'd like

After discussing this with @maxrjones and @norlandrhagen, I think there are two leading approaches to caching that Xbatcher should explore.

  1. Explicitly dump a generator to a cache

This option is similar to what @leifdenby has proposed in #40. This would require users to explicitly pre-process their generator into something optimized for batch generation. In practice, it would look something like this:

from xbatcher import BatchGenerator

# create a batch generator
bgen = BatchGenerator(da, {'time': 10})
# dump to a zarr store
bgen.to_zarr('s3://foo/bar.zarr')

# load from a zarr store
bgen = BatchGenerator.from_zarr('s3://foo/bar.zarr')
for batch in bgen:
    # use batch ...
    ...

I like this approach because it produces a zarr store that could be easily interrogated by an interested user. The main downside is that it requires an explicit cache dump step.
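
For illustration, the dump and load steps could be implemented roughly as below; the one-group-per-batch layout and the helper names are assumptions rather than a settled design, and the proposed to_zarr/from_zarr methods would also need to record enough metadata (batch count, input_dims, etc.) to reconstruct the generator:

import xarray as xr

def dump_batches(bgen, store):
    # write each batch to its own group inside a single zarr store
    # (assumes each batch is an xarray Dataset, or a DataArray on recent xarray)
    n = 0
    for i, batch in enumerate(bgen):
        mode = 'w' if i == 0 else 'a'
        batch.to_zarr(store, group=f'batch_{i}', mode=mode)
        n = i + 1
    return n  # number of batches written

def load_batches(store, n_batches):
    # lazily reopen the cached batches in the order they were written
    for i in range(n_batches):
        yield xr.open_zarr(store, group=f'batch_{i}')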

  2. Attach a configurable cache to the BatchGenerator class

This option would push cache management inside the generator itself, so that generating a batch would first check whether it already exists in a configurable cache. The top-level API could look something like this:

import zarr
import zict
from xbatcher import BatchGenerator

# user creates a cache that follows a MutableMapping interface;
# zict.Buffer keeps up to n items in memory before spilling to the FSStore
cache = zict.Buffer({}, zarr.storage.FSStore('s3://foo'), n=100)
bgen = BatchGenerator(da, {'time': 10}, cache=cache)
for batch in bgen:
    # use batch ... (caching happens behind the scenes on a batch-by-batch basis)
    ...

for batch in bgen:
    # this will be much faster since all batches are cached
    ...

# get a single batch from the cache
bgen.cache[12]

I like this approach because it has the potential to be highly configurable (when needed) and does not require a manual cache dump. The main downside I see is that the cache will be split into a bunch of small datasets (probably zarr stores).
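
As a rough sketch of what the behind-the-scenes lookup could do, assuming Dataset batches keyed by their integer index and serialized to bytes so that any MutableMapping of bytes works (none of these names are real xbatcher API):

import xarray as xr

def get_batch(bgen, idx):
    """Return batch ``idx``, consulting ``bgen.cache`` (a MutableMapping) first."""
    key = str(idx)
    if bgen.cache is not None and key in bgen.cache:
        # cache hit: the stored value is an in-memory netCDF byte string
        return xr.open_dataset(bgen.cache[key])
    # cache miss: fall back to the normal slow path
    # (iterating the generator here stands in for xbatcher's internals)
    batch = next(b for i, b in enumerate(bgen) if i == idx)
    if bgen.cache is not None:
        bgen.cache[key] = batch.to_netcdf()  # bytes, so any MutableMapping can hold it
    return batch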

Describe alternatives you've considered

We should consider not supporting caching in Xbatcher and instead developing a few recipes for how to use Dask caching (e.g. https://docs.dask.org/en/stable/caching.html or https://github.com/radix-ai/graphchain) or caching at the data-loader level (e.g. https://www.tensorflow.org/api_docs/python/tf/data/Dataset#cache).
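
For instance, the opportunistic-cache recipe from the Dask docs linked above is only a couple of lines (the 2 GB budget is arbitrary, and it requires the optional cachey dependency):

from dask.cache import Cache

# register a global opportunistic cache; subsequent dask computations
# (including the ones triggered when xbatcher loads a batch)
# reuse intermediate results up to roughly 2 GB
cache = Cache(2e9)
cache.register()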

See also this interesting conversation where this feature is discussed in the Pytorch context: pytorch/pytorch#35642

