Generating the batches seems slow #37

Open
@robintw

Description

I've just come across xbatcher, and I think it could be just what I need for using CNNs on data stored in dask-backed xarrays. I've got a number of questions about how it works, and some issues I'm having. If this isn't the appropriate place for these questions then please let me know and I'll direct them elsewhere. I decided not to create a separate issue for each question - I expect a number of them aren't actually problems with xbatcher but problems with my understanding, and that would clog up the issue tracker - but if any of them do need extracting into their own issue then I'm happy to do that.

Firstly, thanks for putting this together - it has already solved a lot of problems for me.

To give some context, I'm trying to use xbatcher to run batches of image data through a PyTorch CNN on Microsoft Planetary Computer. I'm not doing any training here, just inference - so I only need to push the raw array data through the ML model and get the results out.

Now on to the questions:

1. Generating the batches seems slow
I'm trying to create batches from a DataArray of size (8172, 8082), which is a single band of a satellite image. I'm using the following call to create a BatchGenerator:

patch_size = 64
batch_size = 10
bgen = b1.batch.generator(input_dims=dict(x=patch_size, y=patch_size),
                          batch_dims=dict(x=batch_size*patch_size, y=batch_size*patch_size),
                          concat_input_dims=True, preload_batch=False)

That should create DataArrays that are 64 x 64 (in x and y), with 100 of those patches in each batch.
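As a quick sanity check on that, here's a sketch I'd run (the exact dimension names produced by concat_input_dims=True are my assumption, so they may differ):

first = next(iter(bgen))
print(first.dims)   # expecting something like ('sample', 'x_input', 'y_input')
print(first.sizes)  # with sample == 100 and x_input == y_input == 64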

I'm then running a loop over the batch generator, doing something with the batches. We'll come to what I'm doing later - but for the moment let's just append each result to a list:

import tqdm

results = []
for batch in tqdm.tqdm(bgen):
    results.append(batch)

This takes around 1s per batch, and each iteration submits a very small Dask task that goes away and generates the batch (I've already run b1.persist() to ensure all the data is on the Dask cluster). I have a few questions about this:

a) Is this sort of speed expected? From some rough calculations, at 1s per batch of 64 x 64 patches, it'll take hours to batch up my ~8000 x 8000 array.
b) With preload_batch=False I'd expect the batches to be generated lazily - and the underlying data in each batch does seem to be a dask array - yet it still takes around a second per batch.
c) Should I be approaching this in a different way to get better speed? (The timing comparison sketched below is how I've been trying to narrow this down.)
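For reference, here's that rough diagnostic - a sketch, assuming b1 and bgen as defined above - to separate the cost of xbatcher's slicing from the cost of pulling data out of dask:

import time

# Pull one 64 x 64 patch directly with isel, bypassing xbatcher entirely
t0 = time.perf_counter()
patch = b1.isel(x=slice(0, 64), y=slice(0, 64)).load()
print(f"raw isel + load: {time.perf_counter() - t0:.3f}s")

# Pull the first batch from the generator for comparison
t0 = time.perf_counter()
first_batch = next(iter(bgen))
print(f"first batch from bgen: {time.perf_counter() - t0:.3f}s")

If the raw isel is fast but the batch is slow, the time is presumably going into xbatcher's slicing and concatenation rather than into dask itself.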

2. How do you put batches back together after processing?
My machine learning model produces a single value as its output, so for a batch of 100 64x64 patches I get a 100-element output array. What's the best way of putting this back into a DataArray with the same format/co-ordinates as the original input array? I'd be happy with either an array with dimensions of original_size / 64 in both x and y, or an array of the same size as the input with the single output value repeated for each of the input pixels in that batch.

I've tried to put some of this together myself, but it seems that the x co-ordinate values in the batch DataArray are the same for every batch. I'd have thought these would be the x co-ordinates that were extracted from the original DataArray, but that doesn't seem to be the case. For example, if I run:

batches = []
for i, batch in enumerate(bgen):
    batches.append(batch)
    if i == 1:
        break

to get the first two batches, I can then compare their x co-ordinate values:

import numpy as np

np.all(batches[0].to_array().squeeze().x == batches[1].to_array().squeeze().x)

and it shows that they're all equal.

Do you have any ideas as to what I could do to be able to put the batches back together?
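To show concretely what I'm after, here's a sketch of the reassembly I'd do by hand if I bypassed xbatcher entirely (model here is a placeholder for my CNN, returning one value per 64 x 64 patch):

import numpy as np
import xarray as xr

ny, nx = b1.sizes["y"], b1.sizes["x"]
out = np.full((ny // patch_size, nx // patch_size), np.nan)

# Walk over non-overlapping patches, storing one model output per patch
for j in range(0, (ny // patch_size) * patch_size, patch_size):
    for i in range(0, (nx // patch_size) * patch_size, patch_size):
        patch = b1.isel(y=slice(j, j + patch_size), x=slice(i, i + patch_size))
        out[j // patch_size, i // patch_size] = model(patch.values)  # placeholder

# Label each output cell with the centre co-ordinate of its source patch
result = xr.DataArray(
    out,
    dims=("y", "x"),
    coords={
        "y": b1.y.values[patch_size // 2 :: patch_size][: out.shape[0]],
        "x": b1.x.values[patch_size // 2 :: patch_size][: out.shape[1]],
    },
)

That works, but it gives up the batching entirely - which is why I was hoping to recover each batch's position from its co-ordinates instead.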

3. Documentation and tutorial notebook
It took me quite a while to find the example notebook sitting in #31 - but the notebook was really helpful (actually a lot more helpful than the documentation on ReadTheDocs). Could this be merged soon, with a prominent link to it added to the docs/README? I think this would significantly help other users trying to get to grips with xbatcher.

4. Overlap seems to cause hang
Trying to batch the array with an overlap seems to take ages to do anything - I'm not sure whether it has hung or is just taking a long time to do the computations. If I run:

patch_size = 64
batch_size = 10
bgen = b1.batch.generator(input_dims=dict(x=patch_size, y=patch_size),
                          batch_dims=dict(x=batch_size*patch_size, y=batch_size*patch_size),
                          input_overlap=dict(x=patch_size-1, y=patch_size-1),
                          concat_input_dims=True,
                          preload_batch=False)

and then try to get the first two batches:

batches = []
for i, batch in enumerate(bgen):
    batches.append(batch)
    if i == 1:
        break

I leave it running for a few minutes, and nothing seems to happen. When I interrupt it, it seems to be deep inside xarray/pandas index handling code.
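One thought: if input_overlap works the way I think it does, an overlap of patch_size - 1 means a stride of 1, which would imply an enormous number of patches. A back-of-the-envelope count (assuming my reading of the parameter is right):

# Stride is patch size minus overlap: 64 - 63 = 1
stride = patch_size - (patch_size - 1)
n_x = (8172 - patch_size) // stride + 1
n_y = (8082 - patch_size) // stride + 1
print(n_x * n_y)  # ~65 million patches

If that's right, it may not have hung at all - it may just be building tens of millions of index slices.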

Any idea what's going on?
