Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Zarr Sink Frame Values #1625

Open
wants to merge 5 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
196 changes: 181 additions & 15 deletions sources/zarr/large_image_source_zarr/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,9 @@
self._largeImagePath = str(self._getLargeImagePath())
self._zarr = None
self._editable = False
self._frameValues = None
self._frameAxes = None
self._frameUnits = None
if not os.path.isfile(self._largeImagePath) and '//:' not in self._largeImagePath:
raise TileSourceFileNotFoundError(self._largeImagePath) from None
try:
Expand Down Expand Up @@ -152,6 +155,9 @@
self._imageDescription = None
self._levels = []
self._associatedImages = {}
self._frameValues = None
self._frameAxes = None
self._frameUnits = None

def __del__(self):
if not hasattr(self, '_derivedSource'):
Expand Down Expand Up @@ -218,10 +224,10 @@
tuple that is used to find the maximum size array, preferring ome
arrays, then total pixels, then channels. 'is_ome' is a boolean.
'series' is a list of the found groups and arrays that match the
best criteria. 'axes' and 'channels' are from the best array.
'associated' is a list of all groups and arrays that might be
associated images. These have to be culled for the actual groups
used in the series.
best criteria. 'axes', 'axes_values', 'axes_units', and 'channels'
are from the best array. 'associated' is a list of all groups and
arrays that might be associated images. These have to be culled
for the actual groups used in the series.
"""
attrs = group.attrs.asdict() if group is not None else {}
min_version = packaging.version.Version('0.4')
Expand All @@ -233,9 +239,19 @@
all(packaging.version.Version(m['version']) >= min_version
for m in attrs['multiscales'] if 'version' in m))
channels = None
axes_values = None
axes_units = None
if is_ome:
axes = {axis['name']: idx for idx, axis in enumerate(
attrs['multiscales'][0]['axes'])}
axes_values = {
axis['name']: axis.get('values')
for axis in attrs['multiscales'][0]['axes']
}
axes_units = {
axis['name']: axis.get('unit')
for axis in attrs['multiscales'][0]['axes']
}
if isinstance(attrs['omero'].get('channels'), list):
channels = [channel['label'] for channel in attrs['omero']['channels']]
if all(channel.startswith('Channel ') for channel in channels):
Expand All @@ -254,6 +270,8 @@
results['series'] = [(group, arr)]
results['is_ome'] = is_ome
results['axes'] = axes
results['axes_values'] = axes_values
results['axes_units'] = axes_units
results['channels'] = channels
elif check == results['best']:
results['series'].append((group, arr))
Expand Down Expand Up @@ -333,6 +351,43 @@
except Exception:
pass

def _readFrameValues(self, found, baseArray):
axes_values = found.get('axes_values')
axes_units = found.get('axes_units')
if isinstance(axes_values, dict):
self._frameAxes = [
a for a, i in sorted(self._axes.items(), key=lambda x: x[1])
if axes_values.get(a) is not None
]
self._frameUnits = {k: axes_units.get(k) for k in self.frameAxes if k in axes_units}
frame_values_shape = [baseArray.shape[self._axes[a]] for a in self.frameAxes]
frame_values_shape.append(len(frame_values_shape))
frame_values = np.empty(frame_values_shape, dtype=object)
for axis, values in axes_values.items():
if axis in self.frameAxes:
slicing = [slice(None) for i in range(len(frame_values_shape))]
axis_index = self.frameAxes.index(axis)
slicing[-1] = axis_index
if isinstance(values, list):
# uniform values are written as lists
for i, value in enumerate(values):
slicing[axis_index] = i
frame_values[tuple(slicing)] = value
elif isinstance(values, dict):
# non-uniform values are written as dicts
# mapping values to index permutations
for value, frame_specs in values.items():
if isinstance(value, str):
try:
value = float(value) if '.' in value else int(value)
except Exception:
pass

Check warning on line 384 in sources/zarr/large_image_source_zarr/__init__.py

View check run for this annotation

Codecov / codecov/patch

sources/zarr/large_image_source_zarr/__init__.py#L383-L384

Added lines #L383 - L384 were not covered by tests
for frame_spec in frame_specs:
for a, i in frame_spec.items():
slicing[self.frameAxes.index(a)] = i
frame_values[tuple(slicing)] = value
self._frameValues = frame_values

def _validateZarr(self):
"""
Validate that we can read tiles from the zarr parent group in
Expand Down Expand Up @@ -386,8 +441,10 @@
stride = 1
self._strides = {}
self._axisCounts = {}
for _, k in sorted((-'tzc'.index(k) if k in 'tzc' else 1, k)
for k in self._axes if k not in 'xys'):
for _, k in sorted(
(-self._axes.get(k, 'tzc'.index(k) if k in 'tzc' else -1), k)
for k in self._axes if k not in 'xys'
):
self._strides[k] = stride
self._axisCounts[k] = baseArray.shape[self._axes[k]]
stride *= baseArray.shape[self._axes[k]]
Expand All @@ -396,6 +453,7 @@
self._axisCounts['xy'] = len(self._series)
stride *= len(self._series)
self._framecount = stride
self._readFrameValues(found, baseArray)

def _nonemptyLevelsList(self, frame=0):
"""
Expand Down Expand Up @@ -442,11 +500,32 @@
result = super().getMetadata()
if self._framecount > 1:
result['frames'] = frames = []
if self.frameValues is not None and self.frameAxes is not None:
for i, axis in enumerate(self.frameAxes):
all_frame_values = self.frameValues[..., i]
split = np.split(all_frame_values, all_frame_values.shape[i], axis=i)
values = [a.flat[0] for a in split]
uniform = all(len(np.unique(a)) == 1 for a in split)
result['Value' + axis.upper()] = dict(
values=values,
uniform=uniform,
units=self.frameUnits.get(axis) if self.frameUnits is not None else None,
min=min(values),
max=max(values),
datatype=np.array(values).dtype.name,
)
for idx in range(self._framecount):
frame = {'Frame': idx}
for axis in self._strides:
frame['Index' + axis.upper()] = (
idx // self._strides[axis]) % self._axisCounts[axis]
if self.frameValues is not None and self.frameAxes is not None:
current_frame_slice = tuple(
frame['Index' + axis.upper()] for axis in self.frameAxes
)
current_frame_values = self.frameValues[current_frame_slice]
for i, axis in enumerate(self.frameAxes):
frame['Value' + axis.upper()] = current_frame_values[i]
frames.append(frame)
self._addMetadataFrameInformation(result, getattr(self, '_channels', None))
return result
Expand Down Expand Up @@ -596,7 +675,11 @@
placement = {
'x': x,
'y': y,
**kwargs,
**{k: v for k, v in kwargs.items() if not k.endswith('_value')},
}
frame_values = {
k.replace('_value', ''): v for k, v in kwargs.items()
if k not in placement
}
tile, mask, placement, axes = self._validateNewTile(tile, mask, placement, axes)

Expand All @@ -615,6 +698,21 @@
for i, a in enumerate(axes)
])

if len(frame_values.keys()) > 0:
self.frameAxes = [a for a in axes if a in frame_values]
frames_shape = [new_dims[a] for a in self.frameAxes]
frames_shape.append(len(frames_shape))
if self.frameValues is None:
self.frameValues = np.empty(frames_shape, dtype=object)
elif self.frameValues.shape != frames_shape:
self.frameValues = np.pad(
self.frameValues,
[(0, s - self.frameValues.shape[i]) for i, s in enumerate(frames_shape)],
)
current_frame_slice = tuple(placement.get(a) for a in self.frameAxes)
for i, k in enumerate(self.frameAxes):
self.frameValues[(*current_frame_slice, i)] = frame_values.get(k)

current_arrays = dict(self._zarr.arrays())
if store_path == '0':
# if writing to base data, invalidate generated levels
Expand Down Expand Up @@ -690,6 +788,39 @@
)
self._associatedImages[imageKey] = (group, arr)

def _getAxisInternalMetadata(self, axis_name):
    """
    Build the metadata record that describes one axis of the image.

    :param axis_name: the name of the axis to describe.
    :returns: a dictionary with a 'name' entry and, where applicable:
        'type' and 'unit' for the spatial and channel axes; 'values' when
        the axis is a frame axis and frame values exist; and 'unit' from
        the frame units.  'values' is a list with one value per index
        along the axis when every frame at a given index shares the same
        value, and otherwise a dictionary mapping each value to the list
        of ``{axis name: index}`` frames that carry it.
    """
    axis_metadata = {'name': axis_name}
    if axis_name in ('x', 'y'):
        axis_metadata['type'] = 'space'
        axis_metadata['unit'] = 'millimeter'
    elif axis_name in ('s', 'c'):
        axis_metadata['type'] = 'channel'

    frame_axes = self.frameAxes
    frame_values = self.frameValues
    if frame_axes is None or frame_values is None or axis_name not in frame_axes:
        return axis_metadata

    axis_index = frame_axes.index(axis_name)
    # The trailing dimension of the frame values selects the frame axis.
    all_frame_values = frame_values[..., axis_index]
    sections = np.split(
        all_frame_values,
        all_frame_values.shape[axis_index],
        axis=axis_index,
    )

    def is_uniform(section):
        # Compare by equality rather than np.unique: np.unique sorts, and
        # object arrays holding None or mixed types cannot be ordered.
        if section.size == 0:
            return False
        first = section.flat[0]
        return all(entry == first for entry in section.flat)

    if all(is_uniform(section) for section in sections):
        values = [section.flat[0] for section in sections]
    else:
        values = {}
        for index, value in np.ndenumerate(all_frame_values):
            if value is None:
                # None marks a frame with no recorded value; there is
                # nothing to write for it.
                continue
            values.setdefault(value, []).append(dict(zip(frame_axes, index)))
    axis_metadata['values'] = values

    frame_units = self.frameUnits
    unit = frame_units.get(axis_name) if frame_units is not None else None
    if unit is not None:
        axis_metadata['unit'] = unit
    return axis_metadata

def _writeInternalMetadata(self):
self._checkEditable()
with self._threadLock and self._processLock:
Expand All @@ -715,17 +846,11 @@
}
datasets.append(dataset_metadata)
for a in sorted_axes:
axis_metadata = {'name': a}
if a in ['x', 'y']:
axis_metadata['type'] = 'space'
axis_metadata['unit'] = 'millimeter'
elif a in ['s', 'c']:
axis_metadata['type'] = 'channel'
elif a == 't':
if a == 't':
rdefs['defaultT'] = 0
elif a == 'z':
rdefs['defaultZ'] = 0
axes.append(axis_metadata)
axes.append(self._getAxisInternalMetadata(a))
if channel_axis is not None and len(arrays) > 0:
base_array = list(arrays.values())[0]
base_shape = base_array.shape
Expand Down Expand Up @@ -860,6 +985,47 @@
self._checkEditable()
self._channelColors = colors

@property
def frameAxes(self):
    """
    The stored frame axes: a list of frame axis names, or None when no
    frame axes have been set.
    """
    return self._frameAxes

@frameAxes.setter
def frameAxes(self, axes):
    """
    Store the frame axes.  No validation of ``axes`` is performed here.

    :param axes: the value to store; the other frame setters expect a list
        of frame axis names.
    """
    # NOTE(review): _checkEditable is defined outside this view; presumably
    # it raises when the source is not open for editing -- confirm.
    self._checkEditable()
    self._frameAxes = axes

@property
def frameUnits(self):
    """
    The stored frame units: a dictionary mapping frame axis names to
    units, or None when no units have been set.
    """
    return self._frameUnits

@frameUnits.setter
def frameUnits(self, units):
    """
    Store the frame units.

    :param units: a dictionary whose keys all exist in ``frameAxes``, or
        None to clear any stored units (readers treat None as "no units").
    :raises ValueError: if ``frameAxes`` has not been set, or if ``units``
        is not a dictionary keyed only by names in ``frameAxes``.
    """
    self._checkEditable()
    if units is None:
        self._frameUnits = None
        return
    if self.frameAxes is None:
        err = 'frameAxes must be set first with a list of frame axis names.'
        raise ValueError(err)
    if not isinstance(units, dict) or any(
        key not in self.frameAxes for key in units
    ):
        err = 'frameUnits must be a dictionary with keys that exist in frameAxes.'
        raise ValueError(err)
    self._frameUnits = units

@property
def frameValues(self):
    """
    The stored frame values array, or None when no values have been set.
    """
    return self._frameValues

@frameValues.setter
def frameValues(self, a):
    """
    Store the frame values.

    :param a: an array with one dimension per entry of ``frameAxes`` plus
        one trailing dimension, or None to clear any stored values.
        Inputs that are not numpy arrays are converted to an object array
        before the dimension check.
    :raises ValueError: if ``frameAxes`` has not been set, or if ``a``
        does not have ``len(frameAxes) + 1`` dimensions.
    """
    self._checkEditable()
    if a is None:
        self._frameValues = None
        return
    if self.frameAxes is None:
        err = 'frameAxes must be set first with a list of frame axis names.'
        raise ValueError(err)
    if not isinstance(a, np.ndarray):
        # Accept nested sequences; without this the dimension check would
        # fail with AttributeError rather than a meaningful ValueError.
        a = np.asarray(a, dtype=object)
    if a.ndim != len(self.frameAxes) + 1:
        err = f'frameValues must have {len(self.frameAxes) + 1} dimensions.'
        raise ValueError(err)
    self._frameValues = a

def _generateDownsampledLevels(self, resample_method):
self._checkEditable()
current_arrays = dict(self._zarr.arrays())
Expand Down
Loading