From 8d30b2071d6869afffe780e3a5a972b0587d72e9 Mon Sep 17 00:00:00 2001 From: arbennett Date: Wed, 4 Jan 2023 18:31:00 -0500 Subject: [PATCH] Add ability to shuffle selectors --- xbatcher/generators.py | 21 +++++++++++++++++++++ 1 file changed, 21 insertions(+) diff --git a/xbatcher/generators.py b/xbatcher/generators.py index d0fcfaf..49b3aef 100644 --- a/xbatcher/generators.py +++ b/xbatcher/generators.py @@ -1,6 +1,7 @@ """Classes for iterating through xarray datarrays / datasets in batches.""" import itertools +import random import warnings from operator import itemgetter from typing import Any, Dict, Hashable, Iterator, List, Optional, Sequence, Union @@ -45,6 +46,8 @@ class BatchSchema: preload_batch : bool, optional If ``True``, each batch will be loaded into memory before reshaping / processing, triggering any dask arrays to be computed. + shuffle : bool, optional + If ``True``, batches will be in a shuffled order Notes ----- @@ -59,6 +62,7 @@ def __init__( batch_dims: Optional[Dict[Hashable, int]] = None, concat_input_bins: bool = True, preload_batch: bool = True, + shuffle: bool = False, ): if input_overlap is None: input_overlap = {} @@ -69,6 +73,7 @@ def __init__( self.batch_dims = dict(batch_dims) self.concat_input_dims = concat_input_bins self.preload_batch = preload_batch + self.shuffle = shuffle # Store helpful information based on arguments self._duplicate_batch_dims: Dict[Hashable, int] = { dim: length @@ -98,6 +103,9 @@ def _gen_batch_selectors( """ # Create an iterator that returns an object usable for .isel in xarray patch_selectors = self._gen_patch_selectors(ds) + if self.shuffle: + patch_selectors = list(patch_selectors) + random.shuffle(patch_selectors) # Create the Dict containing batch selectors if self.concat_input_dims: # Combine the patches into batches return self._combine_patches_into_batch(ds, patch_selectors) @@ -364,6 +372,8 @@ class BatchGenerator: preload_batch : bool, optional If ``True``, each batch will be loaded into memory before reshaping / processing, triggering any dask arrays to be computed. + shuffle : bool, optional + If ``True`` batches will be randomly shuffled Yields ------ @@ -379,6 +389,7 @@ def __init__( batch_dims: Dict[Hashable, int] = {}, concat_input_dims: bool = False, preload_batch: bool = True, + shuffle: bool = False, ): self.ds = ds @@ -389,6 +400,7 @@ def __init__( batch_dims=batch_dims, concat_input_bins=concat_input_dims, preload_batch=preload_batch, + shuffle=shuffle, ) @property @@ -411,6 +423,15 @@ def concat_input_dims(self): def preload_batch(self): return self._batch_selectors.preload_batch + def reshuffle(self): + shuffle_idx = list(self._batch_selectors.selectors) + random.shuffle(shuffle_idx) + self._batch_selectors.selectors = { + idx: self._batch_selectors.selectors[shuffled_idx] + for idx, shuffled_idx in zip(self._batch_selectors.selectors, shuffle_idx) + } + + def __iter__(self) -> Iterator[Union[xr.DataArray, xr.Dataset]]: for idx in self._batch_selectors.selectors: yield self[idx]