Assorted improvements for TiledDataset #402

Merged · 15 commits · Aug 26, 2024
4 changes: 4 additions & 0 deletions changelog/402.feature.rst
@@ -0,0 +1,4 @@
Add various features for easier inspection of `TiledDataset`:
- `__repr__` method to output basic dataset info;
- `tiles_shape` property to access the data array shape of each individual tile;
- `slice_tiles` property to apply the same slice to every tile in the dataset.
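
A minimal usage sketch of the three additions (the local ASDF filename is illustrative; any tiled dataset loaded with `dkist.load_dataset` will do):

from dkist import load_dataset

# Load a previously downloaded tiled (e.g. VBI) dataset; the path is hypothetical
ds = load_dataset("test_vbi.asdf")

print(repr(ds))        # __repr__: basic info about the whole TiledDataset
print(ds.tiles_shape)  # nested list of the data array shape of every tile

# slice_tiles applies the same slice to every tile and returns a new TiledDataset
first_frame = ds.slice_tiles[0]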
9 changes: 9 additions & 0 deletions dkist/conftest.py
@@ -282,6 +282,15 @@ def simple_tiled_dataset(dataset):
    return TiledDataset(dataset_array, dataset.meta["inventory"])


@pytest.fixture
def large_tiled_dataset(tmp_path_factory):
    """Decompress the bundled gzipped VBI ASDF into a temporary directory and load it."""
    vbidir = tmp_path_factory.mktemp("data")
    with gzip.open(Path(rootdir) / "large_vbi.asdf.gz", mode="rb") as gfo:
        with open(vbidir / "test_vbi.asdf", mode="wb") as afo:
            afo.write(gfo.read())
    return load_dataset(vbidir / "test_vbi.asdf")


@pytest.fixture
def small_visp_dataset():
"""
Binary file added dkist/data/test/large_vbi.asdf.gz
16 changes: 16 additions & 0 deletions dkist/dataset/tests/test_tiled_dataset.py
@@ -28,6 +28,14 @@ def test_tiled_dataset_slice(simple_tiled_dataset, aslice):
    assert np.all(simple_tiled_dataset[aslice] == simple_tiled_dataset._data[aslice])


@pytest.mark.parametrize("aslice", [np.s_[0, :100, 100:200]])
def test_tiled_dataset_slice_tiles(large_tiled_dataset, aslice):
    sliced = large_tiled_dataset.slice_tiles[aslice]
    for tile in sliced.flat:
        # tile.shape raises an AttributeError here (reason unclear), so check the data array directly
        assert tile.data.shape == (100, 100)


def test_tiled_dataset_headers(simple_tiled_dataset, dataset):
    assert len(simple_tiled_dataset.combined_headers) == len(dataset.meta["headers"]) * 4
    assert simple_tiled_dataset.combined_headers.colnames == dataset.meta["headers"].colnames
@@ -75,3 +83,11 @@ def test_tileddataset_plot(share_zscale):
    fig = plt.figure(figsize=(600, 800))
    ds.plot(0, share_zscale=share_zscale)
    return plt.gcf()

def test_repr(simple_tiled_dataset):
    r = repr(simple_tiled_dataset)
    assert str(simple_tiled_dataset[0, 0].data) in r


def test_tiles_shape(simple_tiled_dataset):
    assert simple_tiled_dataset.tiles_shape == [[tile.data.shape for tile in row] for row in simple_tiled_dataset]
38 changes: 38 additions & 0 deletions dkist/dataset/tiled_dataset.py
@@ -5,6 +5,7 @@
but not representable in a single NDCube derived object as the array data are
not contiguous in the spatial dimensions (due to overlaps and offsets).
"""
from textwrap import dedent
from collections.abc import Collection

import matplotlib.pyplot as plt
@@ -13,10 +14,26 @@
from astropy.table import vstack

from .dataset import Dataset
from .utils import dataset_info_str

__all__ = ["TiledDataset"]


class TiledDatasetSlicer:
    """
    Basic helper class that applies the same slice to every tile of a `TiledDataset`.
    """
    def __init__(self, data, inventory):
        self.data = data
        self.inventory = inventory

    def __getitem__(self, slice_):
        new_data = []
        for tile in self.data.flat:
            new_data.append(tile[slice_])
        return TiledDataset(np.array(new_data).reshape(self.data.shape), self.inventory)


class TiledDataset(Collection):
    """
    Holds a grid of `.Dataset` objects.
@@ -125,6 +142,13 @@ def shape(self):
        """
        return self._data.shape

    @property
    def tiles_shape(self):
        """
        The shape of each individual tile in the TiledDataset.
        """
        return [[tile.data.shape for tile in row] for row in self]

    def plot(self, slice_index: int, share_zscale=False, **kwargs):
        vmin, vmax = np.inf, 0
        fig = plt.figure()
@@ -151,4 +175,18 @@ def plot(self, slice_index: int, share_zscale=False, **kwargs):
        fig.suptitle(f"{self.inventory['instrumentName']} Dataset ({self.inventory['datasetId']}) at time {timestamp} (slice={slice_index})", y=0.95)
        return fig

    @property
    def slice_tiles(self):
        """Return a slicer object so that the same slice can be applied to every tile: ``ts.slice_tiles[...]``."""
        return TiledDatasetSlicer(self._data, self.inventory)

    # TODO: def regrid()

    def __repr__(self):
        """
        Overload the NDData repr because it does not play nice with the dask delayed io.
        """
        prefix = object.__repr__(self)
        return dedent(f"{prefix}\n{self.__str__()}")

    def __str__(self):
        return dataset_info_str(self)
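
The `slice_tiles` property above works by handing back a small helper object rather than doing any slicing itself; indexing that helper then fans the slice out to every tile and rebuilds the grid. A stripped-down, self-contained sketch of the same pattern (the `Grid` and `_GridSlicer` names are illustrative, not part of dkist):

import numpy as np


class _GridSlicer:
    """Fan a single slice out to every element of a 2D object array."""

    def __init__(self, grid):
        self.grid = grid  # numpy object array of sliceable tiles

    def __getitem__(self, slice_):
        out = np.empty(self.grid.shape, dtype=object)
        for idx, tile in np.ndenumerate(self.grid):
            out[idx] = tile[slice_]
        return out


class Grid:
    def __init__(self, tiles):
        self._tiles = tiles

    @property
    def slice_tiles(self):
        # Returning the helper is what makes grid.slice_tiles[...] syntax work
        return _GridSlicer(self._tiles)


tiles = np.empty((2, 2), dtype=object)
for idx in np.ndindex(tiles.shape):
    tiles[idx] = np.arange(12).reshape(3, 4)

grid = Grid(tiles)
sliced = grid.slice_tiles[0, :2]
print(sliced[0, 0].shape)  # (2,) -- the same slice was applied to every tile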
15 changes: 13 additions & 2 deletions dkist/dataset/utils.py
@@ -10,15 +10,26 @@


 def dataset_info_str(ds):
+    # Check for an attribute that only appears on TiledDataset
+    # Not using isinstance to avoid circular import
+    is_tiled = hasattr(ds, "combined_headers")
+    dstype = type(ds).__name__
+    if is_tiled:
+        tile_shape = ds.shape
+        ds = ds[0, 0]
     wcs = ds.wcs.low_level_wcs

     # Pixel dimensions table

-    instr = ds.meta.get("instrument_name", "")
+    instr = ds.inventory.get("instrument", "")
     if instr:
         instr += " "

-    s = f"This {instr}Dataset has {wcs.pixel_n_dim} pixel and {wcs.world_n_dim} world dimensions\n\n"
+    if is_tiled:
+        s = f"This {dstype} consists of an array of {tile_shape} Dataset objects\n\n"
+        s += f"Each {instr}Dataset has {wcs.pixel_n_dim} pixel and {wcs.world_n_dim} world dimensions\n\n"
+    else:
+        s = f"This {instr}Dataset has {wcs.pixel_n_dim} pixel and {wcs.world_n_dim} world dimensions\n\n"
     s += f"{ds.data}\n\n"

     array_shape = wcs.array_shape or (0,)
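
For reference, a minimal self-contained sketch of the duck-typed dispatch used in `dataset_info_str` above, assuming only that the tiled container exposes a `combined_headers` attribute while a single dataset does not (the classes here are stand-ins, not the real dkist ones):

class SingleThing:
    pass


class TiledThing:
    combined_headers = []  # attribute that only the tiled container has
    shape = (2, 2)


def info_str(ds):
    # hasattr() check instead of isinstance(ds, TiledThing) so this helper
    # does not need to import the tiled class (avoiding a circular import)
    is_tiled = hasattr(ds, "combined_headers")
    dstype = type(ds).__name__
    if is_tiled:
        return f"This {dstype} consists of an array of {ds.shape} Dataset objects"
    return f"This {dstype} is a single Dataset"


print(info_str(SingleThing()))  # This SingleThing is a single Dataset
print(info_str(TiledThing()))   # This TiledThing consists of an array of (2, 2) Dataset objects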