From a70017f33fb7c502772d8cf087e1aea65c8aeeb9 Mon Sep 17 00:00:00 2001 From: s-kganz Date: Wed, 23 Apr 2025 16:48:29 -0700 Subject: [PATCH 1/2] add resampling/filtering, add tests, fix warnings from ds.dims in tests --- xbatcher/generators.py | 69 ++++++++++++- xbatcher/tests/test_generators.py | 160 ++++++++++++++++++++++++++++-- 2 files changed, 219 insertions(+), 10 deletions(-) diff --git a/xbatcher/generators.py b/xbatcher/generators.py index 2616ab3..e6f5895 100644 --- a/xbatcher/generators.py +++ b/xbatcher/generators.py @@ -10,8 +10,9 @@ import numpy as np import xarray as xr -PatchGenerator = Iterator[dict[Hashable, slice]] -BatchSelector = list[dict[Hashable, slice]] +Selector = dict[Hashable, slice] +PatchGenerator = Iterator[Selector] +BatchSelector = list[Selector] BatchSelectorSet = dict[int, BatchSelector] @@ -414,6 +415,18 @@ class BatchGenerator: cache_preprocess: callable, optional A function to apply to batches prior to caching. Note: The caching API is experimental and subject to change. + filter_fn: callable, optional + Function that determines whether a batch is removed. This function should + take a ``Dataset`` or ``DataArray`` as its first argument, a Selector + object as its second argument, and return ``True`` for batches that should be + kept. + resample_fn: callable, optional + Function that determines the relative importance of this batch for + resampling. This function should have the same signature as ``filter_fn``, + but return a float. + resample_n: int, optional + Number of batches to keep after resampling. Must be larger than zero and + no larger than the number of batches available after filtering. 
Yields ------ @@ -431,6 +444,9 @@ def __init__( preload_batch: bool = True, cache: dict[str, Any] | None = None, cache_preprocess: Callable | None = None, + filter_fn: Callable[..., bool] | None = None, + resample_fn: Callable[..., float] | None = None, + resample_n: int | None = None, ): if input_overlap is None: input_overlap = {} @@ -439,6 +455,9 @@ def __init__( self.ds = ds self.cache = cache self.cache_preprocess = cache_preprocess + self.filter_fn = filter_fn + self.resample_fn = resample_fn + self.resample_n = resample_n self._batch_selectors: BatchSchema = BatchSchema( ds, @@ -449,6 +468,52 @@ def __init__( preload_batch=preload_batch, ) + # Extract the list of selectors for filtering/resampling. Both steps + # can only remove batches, so if this list gets shorter we know + # we have to re-enumerate the selectors property. + if self._batch_selectors.concat_input_dims: + batches = [s for s in self._batch_selectors.selectors[0]] + else: + batches = [s[0] for s in self._batch_selectors.selectors.values()] + + n_initial_batches = len(batches) + + if self.filter_fn is not None: + batches = [b for b in batches if self.filter_fn(self.ds, b)] + if len(batches) == 0: + warnings.warn('Filtering resulted in no batches.') + + if self.resample_fn is not None: + assert ( + self.resample_n is not None + ), 'resample_n must be provided to resample batches.' + assert len(batches) >= self.resample_n, ( + f'Cannot sample {self.resample_n} slices from this dataset ' + f'when there are {len(batches)} available.' + ) + + weight = np.array([self.resample_fn(self.ds, s) for s in batches]) + assert np.any( + weight > 0 + ), 'Sample weight vector does not have any positive values.' + weight = weight / np.sum(weight) + + batches_to_keep = np.random.choice( + len(batches), self.resample_n, replace=False, p=weight + ) + + batches = [batches[i] for i in batches_to_keep] + + # Re-enumerate the list of batches only if filtering or resampling + # occurred. 
+ if len(batches) < n_initial_batches: + if self._batch_selectors.concat_input_dims: + self._batch_selectors.selectors = {0: [b for b in batches]} + else: + self._batch_selectors.selectors = { + i: [b] for i, b in enumerate(batches) + } + @property def input_dims(self): return self._batch_selectors.input_dims diff --git a/xbatcher/tests/test_generators.py b/xbatcher/tests/test_generators.py index 166648c..59ec960 100644 --- a/xbatcher/tests/test_generators.py +++ b/xbatcher/tests/test_generators.py @@ -49,6 +49,30 @@ def sample_ds_3d(): return ds +@pytest.fixture(scope='module') +def sample_filter_fn(): + """ + Sample filter function for testing. + """ + + def myfilter(ds, patch): + return ds.isel(patch).bar.mean() > 4.5 + + return myfilter + + +@pytest.fixture(scope='module') +def sample_resample_fn(): + """ + Sample resample function for testing. + """ + + def myresample(ds, patch): + return ds.isel(patch).foo.mean() + + return myresample + + def test_constructor_dataarray(): """ Test that the xarray.DataArray passed to the batch generator is stored @@ -96,7 +120,7 @@ def test_batch_1d(sample_ds_1d, input_size): validate_generator_length(bg) expected_dims = get_batch_dimensions(bg) for n, ds_batch in enumerate(bg): - assert ds_batch.dims['x'] == input_size + assert ds_batch.sizes['x'] == input_size expected_slice = slice(input_size * n, input_size * (n + 1)) ds_batch_expected = sample_ds_1d.isel(x=expected_slice) xr.testing.assert_identical(ds_batch_expected, ds_batch) @@ -146,7 +170,7 @@ def test_batch_1d_no_coordinate(sample_ds_1d, input_size): validate_generator_length(bg) expected_dims = get_batch_dimensions(bg) for n, ds_batch in enumerate(bg): - assert ds_batch.dims['x'] == input_size + assert ds_batch.sizes['x'] == input_size expected_slice = slice(input_size * n, input_size * (n + 1)) ds_batch_expected = ds_dropped.isel(x=expected_slice) xr.testing.assert_identical(ds_batch_expected, ds_batch) @@ -187,7 +211,7 @@ def test_batch_1d_overlap(sample_ds_1d, 
input_overlap): expected_dims = get_batch_dimensions(bg) stride = input_size - input_overlap for n, ds_batch in enumerate(bg): - assert ds_batch.dims['x'] == input_size + assert ds_batch.sizes['x'] == input_size expected_slice = slice(stride * n, stride * n + input_size) ds_batch_expected = sample_ds_1d.isel(x=expected_slice) xr.testing.assert_identical(ds_batch_expected, ds_batch) @@ -204,11 +228,11 @@ def test_batch_3d_1d_input(sample_ds_3d, input_size): validate_generator_length(bg) expected_dims = get_batch_dimensions(bg) for n, ds_batch in enumerate(bg): - assert ds_batch.dims['x'] == input_size + assert ds_batch.sizes['x'] == input_size # time and y should be collapsed into batch dimension assert ( - ds_batch.dims['sample'] - == sample_ds_3d.dims['y'] * sample_ds_3d.dims['time'] + ds_batch.sizes['sample'] + == sample_ds_3d.sizes['y'] * sample_ds_3d.sizes['time'] ) expected_slice = slice(input_size * n, input_size * (n + 1)) ds_batch_expected = ( @@ -279,8 +303,8 @@ def test_batch_3d_2d_input(sample_ds_3d, input_size): yn, xn = np.unravel_index( n, ( - (sample_ds_3d.dims['y'] // input_size), - (sample_ds_3d.dims['x'] // x_input_size), + (sample_ds_3d.sizes['y'] // input_size), + (sample_ds_3d.sizes['x'] // x_input_size), ), ) expected_xslice = slice(x_input_size * xn, x_input_size * (xn + 1)) @@ -427,3 +451,123 @@ def preproc(ds): ds_cache = bg[1] xr.testing.assert_equal(ds_no_cache, ds_cache) xr.testing.assert_identical(ds_no_cache, ds_cache) + + +def test_filter_1d(sample_ds_1d, sample_filter_fn): + bg = BatchGenerator(sample_ds_1d, input_dims={'x': 5}) + + bg_filter = BatchGenerator( + sample_ds_1d, input_dims={'x': 5}, filter_fn=sample_filter_fn + ) + + assert len(bg_filter) < len(bg) + + for batch in bg_filter: + assert batch.bar.mean() > 4.5 + + +def test_filter_3d(sample_ds_3d, sample_filter_fn): + bg = BatchGenerator(sample_ds_3d, input_dims={'x': 5, 'y': 5, 'time': 5}) + + bg_filter = BatchGenerator( + sample_ds_3d, input_dims={'x': 5, 'y': 5, 'time': 
5}, filter_fn=sample_filter_fn + ) + + assert len(bg_filter) < len(bg) + + for batch in bg_filter: + assert batch.bar.mean() > 4.5 + + +def test_filter_3d_concat(sample_ds_3d, sample_filter_fn): + bg = BatchGenerator( + sample_ds_3d, input_dims={'x': 5, 'y': 5, 'time': 5}, concat_input_dims=True + ) + + bg_filter = BatchGenerator( + sample_ds_3d, + input_dims={'x': 5, 'y': 5, 'time': 5}, + filter_fn=sample_filter_fn, + concat_input_dims=True, + ) + + assert bg_filter[0].sizes['input_batch'] < bg[0].sizes['input_batch'] + + assert (bg_filter[0].bar.mean(dim=['x_input', 'y_input', 'time_input']) > 4.5).all() + + +@pytest.mark.parametrize('n', [5, 10]) +def test_resample_1d(sample_ds_1d, sample_resample_fn, n): + bg = BatchGenerator( + sample_ds_1d, input_dims={'x': 5}, resample_fn=sample_resample_fn, resample_n=n + ) + assert len(bg) == n + + +@pytest.mark.parametrize('n', [10, 50, 100]) +def test_resample_3d(sample_ds_3d, sample_resample_fn, n): + bg = BatchGenerator( + sample_ds_3d, + input_dims={'x': 5, 'y': 5, 'time': 5}, + resample_fn=sample_resample_fn, + resample_n=n, + ) + assert len(bg) == n + + +def test_filter_prevents_resample(sample_ds_3d, sample_resample_fn): + def strict_filter(*args): + return False + + with pytest.raises(AssertionError, match='Cannot sample 1000 slices'): + BatchGenerator( + sample_ds_3d, + input_dims={'x': 5, 'y': 5, 'time': 5}, + resample_fn=sample_resample_fn, + resample_n=1000, + ) + + +def test_error_missing_resample_n(sample_ds_3d): + with pytest.raises(AssertionError, match='resample_n must be provided'): + BatchGenerator( + sample_ds_3d, + input_dims={'x': 5, 'y': 5, 'time': 5}, + resample_fn=sample_resample_fn, + resample_n=None, + ) + + +def test_error_large_resample_n(sample_ds_3d, sample_resample_fn): + with pytest.raises(AssertionError, match='Cannot sample 999999999 slices'): + BatchGenerator( + sample_ds_3d, + input_dims={'x': 5, 'y': 5, 'time': 5}, + resample_fn=sample_resample_fn, + resample_n=999999999, + ) + + +def 
test_error_all_zero_resample_weight(sample_ds_3d): + def zero(*args): + return 0 + + with pytest.raises(AssertionError, match='Sample weight vector'): + BatchGenerator( + sample_ds_3d, + input_dims={'x': 5, 'y': 5, 'time': 5}, + resample_fn=zero, + resample_n=100, + ) + + +def test_warning_empty_filter(sample_ds_3d): + def strict_filter(*args): + return False + + with pytest.warns(UserWarning, match='no batches'): + BatchGenerator( + sample_ds_3d, + input_dims={'x': 5, 'y': 5, 'time': 5}, + filter_fn=strict_filter, + ) From 2e721fcd06a9c560207718008c911061ef6c5512 Mon Sep 17 00:00:00 2001 From: s-kganz Date: Fri, 25 Apr 2025 14:43:15 -0700 Subject: [PATCH 2/2] notebook on filtering/resampling --- doc/user-guide/filtering-and-resampling.ipynb | 685 ++++++++++++++++++ 1 file changed, 685 insertions(+) create mode 100644 doc/user-guide/filtering-and-resampling.ipynb diff --git a/doc/user-guide/filtering-and-resampling.ipynb b/doc/user-guide/filtering-and-resampling.ipynb new file mode 100644 index 0000000..9fbe14d --- /dev/null +++ b/doc/user-guide/filtering-and-resampling.ipynb @@ -0,0 +1,685 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "b314e777-7ffb-4e62-b4c5-ce8a785c5181", + "metadata": {}, + "source": [ + "# Filtering and resampling Xarray datasets with xbatcher\n", + "\n", + "There are many cases in machine learning where we want to discard invalid observations or modify the distribution of a target variable. This notebook demonstrates how `BatchGenerators` can be used to make filtered or resampled datasets by passing functions that identify usable data or assign a sample weight to each patch." 
+ ] + }, + { + "cell_type": "markdown", + "id": "7158f5f3-42f5-4dcd-87ee-045d9d0e85f5", + "metadata": {}, + "source": [ + "### Libraries and toy data" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "id": "5d912ff0-d808-4704-8dea-b9e1b5a53bf1", + "metadata": {}, + "outputs": [], + "source": [ + "import matplotlib.pyplot as plt\n", + "import numpy as np\n", + "import xarray as xr\n", + "\n", + "import xbatcher as xb" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "7fb892c1-50fd-48c8-8567-b150946b53c9", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "
<xarray.Dataset> Size: 31MB\n",
+       "Dimensions:  (lat: 25, time: 2920, lon: 53)\n",
+       "Coordinates:\n",
+       "  * lat      (lat) float32 100B 75.0 72.5 70.0 67.5 65.0 ... 22.5 20.0 17.5 15.0\n",
+       "  * lon      (lon) float32 212B 200.0 202.5 205.0 207.5 ... 325.0 327.5 330.0\n",
+       "  * time     (time) datetime64[ns] 23kB 2013-01-01 ... 2014-12-31T18:00:00\n",
+       "Data variables:\n",
+       "    air      (time, lat, lon) float64 31MB ...\n",
+       "Attributes:\n",
+       "    Conventions:  COARDS\n",
+       "    title:        4x daily NMC reanalysis (1948)\n",
+       "    description:  Data is from NMC initialized reanalysis\\n(4x/day).  These a...\n",
+       "    platform:     Model\n",
+       "    references:   http://www.esrl.noaa.gov/psd/data/gridded/data.ncep.reanaly...
" + ], + "text/plain": [ + " Size: 31MB\n", + "Dimensions: (lat: 25, time: 2920, lon: 53)\n", + "Coordinates:\n", + " * lat (lat) float32 100B 75.0 72.5 70.0 67.5 65.0 ... 22.5 20.0 17.5 15.0\n", + " * lon (lon) float32 212B 200.0 202.5 205.0 207.5 ... 325.0 327.5 330.0\n", + " * time (time) datetime64[ns] 23kB 2013-01-01 ... 2014-12-31T18:00:00\n", + "Data variables:\n", + " air (time, lat, lon) float64 31MB ...\n", + "Attributes:\n", + " Conventions: COARDS\n", + " title: 4x daily NMC reanalysis (1948)\n", + " description: Data is from NMC initialized reanalysis\\n(4x/day). These a...\n", + " platform: Model\n", + " references: http://www.esrl.noaa.gov/psd/data/gridded/data.ncep.reanaly..." + ] + }, + "execution_count": 2, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "ds = xr.tutorial.open_dataset('air_temperature')\n", + "ds" + ] + }, + { + "cell_type": "markdown", + "id": "85fbfe7b-c006-4052-b276-41e2354263a4", + "metadata": {}, + "source": [ + "### Filtering\n", + "\n", + "Here we add a QA variable to the air temperature dataset. Suppose that 1% of the time there is an instrument failure, and we do not want any cells with the QA flag set to go into a model." + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "3e448fe0-af4c-49e9-be8d-529b29cb12fc", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "" + ] + }, + "execution_count": 3, + "metadata": {}, + "output_type": "execute_result" + }, + { + "data": { + "image/png": "", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "ds['qa'] = (('time', 'lat', 'lon'), np.random.rand(*ds.air.shape) < 0.01)\n", + "ds['qa'].isel(time=0).plot()" + ] + }, + { + "cell_type": "markdown", + "id": "a9e888d1-4c4d-4ec7-aa47-8ec7c17c3b63", + "metadata": {}, + "source": [ + "Define a small function to determine which patches to keep. The function should take the underlying dataset as its first argument, and a dictionary of slice objects as the second argument. Each dictionary corresponds to one batch from the `BatchGenerator`. Batches for which the function returns True are retained in the `BatchGenerator`." + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "000a8888-aa69-4bc0-b932-a7879f43fa85", + "metadata": {}, + "outputs": [], + "source": [ + "def myfilter(ds, batch):\n", + " return (ds.isel(**batch).qa == 0).all()" + ] + }, + { + "cell_type": "markdown", + "id": "e615808d-5c5b-46c5-955e-bbe95f01aa35", + "metadata": {}, + "source": [ + "Now we pass the filter function to the batch generator and verify that none of the anomalous pixels make it into resulting batches." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "9afa1d84-8991-4572-a15e-f67b278b169c", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Original generator: 29200 batches\n", + "Filtered generator: 8344 batches\n" + ] + } + ], + "source": [ + "bgen_original = xb.BatchGenerator(ds, {'lat': 5, 'lon': 5, 'time': 5})\n", + "\n", + "bgen_filtered = xb.BatchGenerator(\n", + " ds, {'lat': 5, 'lon': 5, 'time': 5}, filter_fn=myfilter\n", + ")\n", + "\n", + "print('Original generator:', len(bgen_original), 'batches')\n", + "print('Filtered generator:', len(bgen_filtered), 'batches')\n", + "\n", + "for batch in bgen_filtered:\n", + " assert (batch.qa == 0).all()" + ] + }, + { + "cell_type": "markdown", + "id": "32b49b74-3e46-435e-84d6-0c5b0e8610cf", + "metadata": {}, + "source": [ + "### Resampling\n", + "\n", + "Now we show how to resample a `BatchGenerator`. Note that this approach only supports *undersampling*. That is, you can only remove batches, not duplicate them. Now our task is to define a function with the same signature as the filter, but this time returning a non-negative float that indicates the relative sample weight of this patch. This functionality uses `np.random.choice` to select batches, so use `np.random.seed` to ensure reproducibility.\n", + "\n", + "Suppose we want to sample the dataset to emphasize batches with low air temperature. One option is to return a higher sample weight for batches with mean air temperature below a certain threshold." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "26ffc331-558b-49fb-a7fe-2da91bf8276b", + "metadata": {}, + "outputs": [], + "source": [ + "threshold = 270\n", + "\n", + "\n", + "def myresample(ds, batch):\n", + " window_mean = ds.isel(**batch).air.mean()\n", + " if window_mean > threshold:\n", + " return 1\n", + " else:\n", + " return 4\n", + "\n", + "\n", + "bgen_resampled = xb.BatchGenerator(\n", + " ds, {'lat': 5, 'lon': 5, 'time': 5}, resample_fn=myresample, resample_n=5_000\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "6c6b1ec5-8e63-47b0-b8c3-82e188e0109f", + "metadata": {}, + "source": [ + "Now, we can compare the distribution of batch mean air temperature between the resampled and original generator." + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "e3ef2917-080a-40f3-8b47-97794b112cfa", + "metadata": {}, + "outputs": [ + { + "data": { + "image/png": "", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "original_mean = np.array([batch.air.mean() for batch in bgen_original])\n", + "filtered_mean = np.array([batch.air.mean() for batch in bgen_resampled])\n", + "\n", + "fig, (ax1, ax2) = plt.subplots(2, 1, sharex=True)\n", + "\n", + "line = dict(x=threshold, color='black', linestyle='dashed')\n", + "\n", + "ax1.hist(original_mean)\n", + "ax1.axvline(**line)\n", + "ax1.set_title('original')\n", + "ax2.hist(filtered_mean)\n", + "ax2.axvline(**line)\n", + "ax2.set_title('resampled')\n", + "plt.tight_layout()\n", + "plt.show()" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.17" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +}