Regridding xarray dataset with chunked dask-backed arrays #222

Closed
zoj613 opened this issue Jan 12, 2023 · 6 comments · Fixed by #280

Comments

@zoj613

zoj613 commented Jan 12, 2023

Is there a way to reliably regrid an xarray.Dataset object to a lower or higher resolution when it has variables backed by chunked dask arrays? Every time I try to use the output of the call to xesmf.Regridder to regrid the input data, I get a

ValueError: Dimension 1 has 9 blocks, adjust_chunks specified with 1 blocks

exception. To get it to work, I have to force the dataset to have only a single chunk with .chunk(-1). This can cause tasks to fail when the dask graph is computed, since a single chunk of a large dataset can consume a lot of memory. Is there a workaround for this that doesn't require a single chunk?
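
A minimal sketch of the pattern that fails (placeholder global grids from xesmf.util and arbitrary chunk sizes, just to illustrate):

import numpy as np
import xesmf as xe

# placeholder grids; any dataset whose spatial dims span several dask chunks
# reproduces the problem
ds_in = xe.util.grid_global(1, 1)    # source grid
ds_out = xe.util.grid_global(2, 2)   # target grid
ds_in["data"] = (("y", "x"), np.random.rand(*ds_in["lon"].shape))
ds_in = ds_in.chunk({"y": 20, "x": 20})  # dask-backed, multiple spatial chunks

regridder = xe.Regridder(ds_in, ds_out, "bilinear")

regridder(ds_in)            # ValueError: Dimension 1 has N blocks, adjust_chunks specified with 1 blocks
regridder(ds_in.chunk(-1))  # works, but every variable becomes one huge chunk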

@aulemahal
Collaborator

aulemahal commented Jan 12, 2023

Sadly, not for now.
If your data has non-spatial dimensions, like time, I would suggest rechunking by merging all spatial chunks and splitting along the time dimension so that the chunk sizes stay reasonable.
However, the regridding weights are currently stored in a single-chunk sparse matrix and distributing this across the chunks of your data is a complex problem...
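
For example, something along these lines (the dimension names and chunk sizes depend on your data):

# keep each spatial field in a single chunk and split along time instead;
# "time", "lat" and "lon" stand in for whatever your dimensions are called
ds = ds.chunk({"time": 100, "lat": -1, "lon": -1})
regridded = regridder(ds)  # the weights are then applied chunk by chunk over time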

@huard
Contributor

huard commented Jan 16, 2023

If anyone has some ideas to solve this, that would be a great contribution.

@huard
Contributor

huard commented Jan 20, 2023

See https://discourse.pangeo.io/t/conservative-region-aggregation-with-xarray-geopandas-and-sparse/2715 for possible solution.

@huard
Contributor

huard commented Feb 3, 2023

We're hoping to have an intern work on this next summer. If anyone has tips to share, please leave them here.

@dcherian
Contributor

dcherian commented Apr 26, 2023

Here's how to do it

Read (or convert) weights as pydata/sparse

def read_xesmf_weights_file(filename):
    import numpy as np
    import sparse
    import xarray as xr
    
    weights = xr.open_dataset(filename)

    # input variable shape
    in_shape = weights.src_grid_dims.load().data

    # output variable shape
    out_shape = weights.dst_grid_dims.load().data.tolist()[::-1]

    print(f"Regridding from {in_shape} to {out_shape}")

    rows = weights['row'] - 1  # row indices (convert from 1-based to 0-based)
    cols = weights['col'] - 1  # col indices (convert from 1-based to 0-based)
    
    # construct a sparse array,
    # reshape to 3D : lat, lon, ncol
    # This reshaping should allow optional chunking along
    # lat, lon later
    sparse_array_data = sparse.COO(
        coords=np.stack([rows.data, cols.data]), 
        data=weights.S.data, 
        shape=(weights.sizes["n_b"], weights.sizes["n_a"]), 
        fill_value=0,
      ).reshape((*out_shape, -1))
    
    # Create a DataArray with sparse weights and the output coordinates
    xsparse_wgts = xr.DataArray(
        sparse_array_data,
        dims=("lat", "lon", "ncol"),
        # Add useful coordinate information, this will get propagated to the output
        coords={
            "lat": ("lat", weights.yc_b.data.reshape(out_shape)[:, 0]),
            "lon": ("lon", weights.xc_b.data.reshape(out_shape)[0, :]),
        },
        # propagate useful information like regridding algorithm
        attrs=weights.attrs,
    )
    
    return xsparse_wgts
    
xsparse_wgts = read_xesmf_weights_file(map_path + map_file)
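
The 3D reshape above is what allows optionally chunking the weights along the output grid if it is itself large (the chunk sizes here are arbitrary):

# optional: chunk the sparse weights along the output dims;
# dask wraps the sparse.COO blocks directly
xsparse_wgts = xsparse_wgts.chunk({"lat": 90, "lon": 180})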

Apply weights using opt_einsum

https://dgasmith.github.io/opt_einsum/

import numpy as np
import opt_einsum
import xarray as xr

def apply_weights(dataset, weights):

    def _apply(da):
        # 🐵 🔧 monkey-patch: make xr.dot use opt_einsum
        xr.core.duck_array_ops.einsum = opt_einsum.contract

        ans = xr.dot(
            da,
            weights,
            # This dimension will be "contracted",
            # i.e. summed over after multiplying by the weights
            dims="ncol",
        )

        # 🐵 🔧 restore the original einsum
        xr.core.duck_array_ops.einsum = np.einsum

        return ans

    vars_with_ncol = [
        name for name, array in dataset.variables.items()
        if "ncol" in array.dims and name not in weights.coords
    ]
    regridded = dataset[vars_with_ncol].map(_apply)

    # merge in the other variables, but skip those that are already set
    # (like lat, lon)
    return xr.merge([dataset.drop_vars(regridded.variables), regridded])

apply_weights(psfile, xsparse_wgts.chunk())

Gainzzzz

  1. It'll work with chunked inputs (both weights and da). The core piece is _apply.
  2. You can delete smm.py :O
  3. I wrote this for 1D unstructured -> 2D regridding, but it should work even for structured 2D->2D regridding.
  4. Directly using np.einsum, as xr.dot does by default, doesn't work so well for chunked weights (bug report), but in my testing I also found opt_einsum to be a lot faster for plain (numpy data x sparse weights).

See pydata/xarray#7764 for the upstream issue to avoid the monkey-patch
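
If/when xarray grows a built-in switch for this, the monkey-patch could be replaced by something like the following (hypothetical, assuming the option lands as use_opt_einsum):

# hypothetical: assumes a future xarray release exposes a use_opt_einsum
# option so that xr.dot dispatches to opt_einsum internally
with xr.set_options(use_opt_einsum=True):
    ans = xr.dot(da, weights, dims="ncol")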

@aulemahal
Collaborator

@charlesgauthier-udm Here's the "parallelize the application" issue.
