From 4b1dd13871419e19e8dc6ad4e1ce1b0ca5fc482d Mon Sep 17 00:00:00 2001 From: Anne Haley Date: Wed, 4 Sep 2024 15:31:04 +0000 Subject: [PATCH 1/5] Add basic APIs for specifying frame values, add test --- .../zarr/large_image_source_zarr/__init__.py | 97 +++++++++++++++- test/test_sink.py | 106 ++++++++++++++++++ 2 files changed, 200 insertions(+), 3 deletions(-) diff --git a/sources/zarr/large_image_source_zarr/__init__.py b/sources/zarr/large_image_source_zarr/__init__.py index 1296f1287..54dd569bf 100644 --- a/sources/zarr/large_image_source_zarr/__init__.py +++ b/sources/zarr/large_image_source_zarr/__init__.py @@ -98,6 +98,9 @@ def _initOpen(self, **kwargs): self._largeImagePath = str(self._getLargeImagePath()) self._zarr = None self._editable = False + self._frameValues = None + self._frameAxes = None + self._frameUnits = None if not os.path.isfile(self._largeImagePath) and '//:' not in self._largeImagePath: raise TileSourceFileNotFoundError(self._largeImagePath) from None try: @@ -152,6 +155,9 @@ def _initNew(self, path, **kwargs): self._imageDescription = None self._levels = [] self._associatedImages = {} + self._frameValues = None + self._frameAxes = None + self._frameUnits = None def __del__(self): if not hasattr(self, '_derivedSource'): @@ -386,8 +392,10 @@ def _validateZarr(self): stride = 1 self._strides = {} self._axisCounts = {} - for _, k in sorted((-'tzc'.index(k) if k in 'tzc' else 1, k) - for k in self._axes if k not in 'xys'): + for _, k in sorted( + (-self._axes.get(k, 'tzc'.index(k) if k in 'tzc' else -1), k) + for k in self._axes if k not in 'xys' + ): self._strides[k] = stride self._axisCounts[k] = baseArray.shape[self._axes[k]] stride *= baseArray.shape[self._axes[k]] @@ -442,11 +450,32 @@ def getMetadata(self): result = super().getMetadata() if self._framecount > 1: result['frames'] = frames = [] + if self.frameValues is not None and self.frameAxes is not None: + for i, axis in enumerate(self.frameAxes): + all_frame_values = self.frameValues[..., i] + split = np.split(all_frame_values, all_frame_values.shape[i], axis=i) + values = [a.flat[0] for a in split] + uniform = all(len(np.unique(a)) == 1 for a in split) + result['Value' + axis.upper()] = dict( + values=values, + uniform=uniform, + units=self.frameUnits.get(axis) if self.frameUnits is not None else None, + min=min(values), + max=max(values), + datatype=np.array(values).dtype.name + ) for idx in range(self._framecount): frame = {'Frame': idx} for axis in self._strides: frame['Index' + axis.upper()] = ( idx // self._strides[axis]) % self._axisCounts[axis] + if self.frameValues is not None and self.frameAxes is not None: + current_frame_slice = tuple( + frame['Index' + axis.upper()] for axis in self.frameAxes + ) + current_frame_values = self.frameValues[current_frame_slice] + for i, axis in enumerate(self.frameAxes): + frame['Value' + axis.upper()] = current_frame_values[i] frames.append(frame) self._addMetadataFrameInformation(result, getattr(self, '_channels', None)) return result @@ -596,7 +625,11 @@ def addTile(self, tile, x=0, y=0, mask=None, axes=None, **kwargs): placement = { 'x': x, 'y': y, - **kwargs, + **{k: v for k, v in kwargs.items() if not k.endswith('_value')}, + } + frame_values = { + k.replace('_value', ''): v for k, v in kwargs.items() + if k not in placement } tile, mask, placement, axes = self._validateNewTile(tile, mask, placement, axes) @@ -615,6 +648,21 @@ def addTile(self, tile, x=0, y=0, mask=None, axes=None, **kwargs): for i, a in enumerate(axes) ]) + if len(frame_values.keys()) > 0: + self.frameAxes = [a for a in axes if a in frame_values] + frames_shape = [new_dims[a] for a in self.frameAxes] + frames_shape.append(len(frames_shape)) + if self.frameValues is None: + self.frameValues = np.empty(frames_shape, dtype=object) + elif self.frameValues.shape != frames_shape: + self.frameValues = np.pad( + self.frameValues, + [(0, s - self.frameValues.shape[i]) for i, s in enumerate(frames_shape)] + ) + current_frame_slice = tuple(placement.get(a) for a in self.frameAxes) + for i, k in enumerate(self.frameAxes): + self.frameValues[(*current_frame_slice, i)] = frame_values.get(k) + current_arrays = dict(self._zarr.arrays()) if store_path == '0': # if writing to base data, invalidate generated levels @@ -725,6 +773,9 @@ def _writeInternalMetadata(self): rdefs['defaultT'] = 0 elif a == 'z': rdefs['defaultZ'] = 0 + unit = self.frameUnits.get(a) if self.frameUnits is not None else None + if unit is not None: + axis_metadata['unit'] = unit axes.append(axis_metadata) if channel_axis is not None and len(arrays) > 0: base_array = list(arrays.values())[0] @@ -860,6 +911,46 @@ def channelColors(self, colors): self._checkEditable() self._channelColors = colors + @property + def frameAxes(self): + return self._frameAxes + + @frameAxes.setter + def frameAxes(self, axes): + self._checkEditable() + self._frameAxes = axes + + @property + def frameUnits(self): + return self._frameUnits + + @frameUnits.setter + def frameUnits(self, units): + self._checkEditable() + if self.frameAxes is None: + err = 'frameAxes must be set first with a list of frame axis names.' + raise ValueError + if not isinstance(units, dict) or not all( + k in self.frameAxes for k in units.keys() + ): + err = 'frameUnits must be a dictionary with keys that exist in frameAxes.' + self._frameUnits = units + + @property + def frameValues(self): + return self._frameValues + + @frameValues.setter + def frameValues(self, a): + self._checkEditable() + if self.frameAxes is None: + err = 'frameAxes must be set first with a list of frame axis names.' + raise ValueError + if len(a.shape) != len(self.frameAxes) + 1: + err = f'frameValues must have {len(self.frameAxes) + 1} dimensions.' + raise ValueError(err) + self._frameValues = a + def _generateDownsampledLevels(self, resample_method): self._checkEditable() current_arrays = dict(self._zarr.arrays()) diff --git a/test/test_sink.py b/test/test_sink.py index df6e7c6ca..3bd5dd74b 100644 --- a/test/test_sink.py +++ b/test/test_sink.py @@ -535,3 +535,109 @@ def testConcurrency(tmp_path): assert data is not None assert data.shape == seed_data.shape assert np.allclose(data, seed_data) + + +def compare_metadata(actual, expected): + assert type(actual) is type(expected) + if isinstance(actual, list): + for i, v in enumerate(actual): + compare_metadata(v, expected[i]) + elif isinstance(actual, dict): + assert len(actual.keys()) == len(expected.keys()) + for k, v in actual.items(): + compare_metadata(v, expected[k]) + else: + assert actual == expected + + +@pytest.mark.parametrize('use_add_tile_args', [True, False]) +def testFrameValuesAddTile(use_add_tile_args, tmp_path): + output_file = tmp_path / 'test.db' + sink = large_image_source_zarr.new() + + frame_shape = (300, 400, 3) + expected = dict( + z=dict(values=[2, 4, 6, 8], uniform=True, units='m', stride=9, dtype='int64'), + t=dict(values=[10.0, 20.0, 30.0], uniform=False, units='ms', stride=3, dtype='float64'), + c=dict(values=['r', 'g', 'b'], uniform=True, units='channel', stride=1, dtype='str32'), + ) + expected_metadata = dict( + levels=1, + sizeY=frame_shape[0], + sizeX=frame_shape[1], + bandCount=frame_shape[2], + frames=[], + tileWidth=512, + tileHeight=512, + magnification=None, + mm_x=0, + mm_y=0, + dtype='float64', + channels=[f'Band {c + 1}' for c in range(len(expected['c']['values']))], + channelmap={f'Band {c + 1}': c for c in range(len(expected['c']['values']))}, + IndexRange={ + f'Index{k.upper()}': len(v['values']) for k, v in expected.items() + }, + IndexStride={ + f'Index{k.upper()}': v['stride'] for k, v in expected.items() + }, + **{ + f'Value{k.upper()}': dict( + values=v['values'], + units=v['units'], + uniform=v['uniform'], + min=min(v['values']), + max=max(v['values']), + datatype=v['dtype'], + ) for k, v in expected.items() + } + ) + + sink.frameAxes = list(expected.keys()) + sink.frameUnits = { + k: v['units'] for k, v in expected.items() + } + frame_values_shape = [ + *[len(v['values']) for v in expected.values()], + len(expected) + ] + frame_values = np.empty(frame_values_shape, dtype=object) + + frame = 0 + index = 0 + for z, z_value in enumerate(expected['z']['values']): + for t, t_value in enumerate(expected['t']['values']): + if not expected['t']['uniform']: + t_value += 0.01 * z + for c, c_value in enumerate(expected['c']['values']): + add_tile_args = dict(z=z, t=t, c=c, axes=['z', 't', 'c', 'y', 'x', 's']) + if use_add_tile_args: + add_tile_args.update(z_value=z_value, t_value=t_value, c_value=c_value) + else: + frame_values[z, t, c] = [z_value, t_value, c_value] + random_tile = np.random.random(frame_shape) + sink.addTile(random_tile, 0, 0, **add_tile_args) + expected_metadata['frames'].append( + dict( + Frame=frame, + Index=index, + IndexZ=z, + ValueZ=z_value, + IndexT=t, + ValueT=t_value, + IndexC=c, + ValueC=c_value, + Channel=f'Band {c + 1}', + ) + ) + frame += 1 + index += 1 + + if not use_add_tile_args: + sink.frameValues = frame_values + + compare_metadata(dict(sink.getMetadata()), expected_metadata) + + # sink.write(output_file) + # written = large_image_source_zarr.open(output_file) + # compare_metadata(dict(written.getMetadata()), expected_metadata) From 68666f809ed264a70a14fef2eaef7af1f7fe55ff Mon Sep 17 00:00:00 2001 From: Anne Haley Date: Wed, 4 Sep 2024 18:47:44 +0000 Subject: [PATCH 2/5] write frame values to internal metadata and reconstruct on read --- .../zarr/large_image_source_zarr/__init__.py | 69 +++++++++++++++++-- test/test_sink.py | 11 ++- 2 files changed, 68 insertions(+), 12 deletions(-) diff --git a/sources/zarr/large_image_source_zarr/__init__.py b/sources/zarr/large_image_source_zarr/__init__.py index 54dd569bf..b717a0524 100644 --- a/sources/zarr/large_image_source_zarr/__init__.py +++ b/sources/zarr/large_image_source_zarr/__init__.py @@ -224,7 +224,7 @@ def _scanZarrArray(self, group, arr, results): tuple that is used to find the maximum size array, preferring ome arrays, then total pixels, then channels. 'is_ome' is a boolean. 'series' is a list of the found groups and arrays that match the - best criteria. 'axes' and 'channels' are from the best array. + best criteria. 'axes', 'axes_values', 'axes_units', and 'channels' are from the best array. 'associated' is a list of all groups and arrays that might be associated images. These have to be culled for the actual groups used in the series. @@ -239,9 +239,13 @@ def _scanZarrArray(self, group, arr, results): all(packaging.version.Version(m['version']) >= min_version for m in attrs['multiscales'] if 'version' in m)) channels = None + axes_values = None + axes_units = None if is_ome: axes = {axis['name']: idx for idx, axis in enumerate( attrs['multiscales'][0]['axes'])} + axes_values = {axis['name']: axis.get('values') for axis in attrs['multiscales'][0]['axes']} + axes_units = {axis['name']: axis.get('unit') for axis in attrs['multiscales'][0]['axes']} if isinstance(attrs['omero'].get('channels'), list): channels = [channel['label'] for channel in attrs['omero']['channels']] if all(channel.startswith('Channel ') for channel in channels): @@ -260,6 +264,8 @@ def _scanZarrArray(self, group, arr, results): results['series'] = [(group, arr)] results['is_ome'] = is_ome results['axes'] = axes + results['axes_values'] = axes_values + results['axes_units'] = axes_units results['channels'] = channels elif check == results['best']: results['series'].append((group, arr)) @@ -404,6 +410,40 @@ def _validateZarr(self): self._axisCounts['xy'] = len(self._series) stride *= len(self._series) self._framecount = stride + axes_values = found.get('axes_values') + axes_units = found.get('axes_units') + if isinstance(axes_values, dict): + self._frameAxes = [ + a for a, i in sorted(self._axes.items(), key=lambda x: x[1]) + if axes_values.get(a) is not None + ] + self._frameUnits = {k: axes_units.get(k) for k in self.frameAxes if k in axes_units} + frame_values_shape = [baseArray.shape[self._axes[a]] for a in self.frameAxes] + frame_values_shape.append(len(frame_values_shape)) + frame_values = np.empty(frame_values_shape, dtype=object) + for axis, values in axes_values.items(): + if axis in self.frameAxes: + slicing = [slice(None) for i in range(len(frame_values_shape))] + axis_index = self.frameAxes.index(axis) + slicing[-1] = axis_index + if isinstance(values, list): + # uniform values are written as lists + for i, value in enumerate(values): + slicing[axis_index] = i + frame_values[tuple(slicing)] = value + elif isinstance(values, dict): + # non-uniform values are written as dicts mapping values to index permutations + for value, frame_specs in values.items(): + if isinstance(value, str): + try: + value = float(value) if '.' in value else int(value) + except: + pass + for frame_spec in frame_specs: + for a, i in frame_spec.items(): + slicing[self.frameAxes.index(a)] = i + frame_values[tuple(slicing)] = value + self._frameValues = frame_values def _nonemptyLevelsList(self, frame=0): """ @@ -773,9 +813,26 @@ def _writeInternalMetadata(self): rdefs['defaultT'] = 0 elif a == 'z': rdefs['defaultZ'] = 0 - unit = self.frameUnits.get(a) if self.frameUnits is not None else None - if unit is not None: - axis_metadata['unit'] = unit + if self.frameAxes is not None: + frame_axis_index = self.frameAxes.index(a) if a in self.frameAxes else None + if frame_axis_index is not None and self.frameValues is not None: + all_frame_values = self.frameValues[..., frame_axis_index] + split = np.split(all_frame_values, all_frame_values.shape[frame_axis_index], axis=frame_axis_index) + uniform = all(len(np.unique(a)) == 1 for a in split) + if uniform: + values = [a.flat[0] for a in split] + else: + values = {} + for index, value in np.ndenumerate(all_frame_values): + if value not in values: + values[value] = [] + values[value].append({ + axis: index[i] for i, axis in enumerate(self.frameAxes) + }) + axis_metadata['values'] = values + unit = self.frameUnits.get(a) if self.frameUnits is not None else None + if unit is not None: + axis_metadata['unit'] = unit axes.append(axis_metadata) if channel_axis is not None and len(arrays) > 0: base_array = list(arrays.values())[0] @@ -919,7 +976,7 @@ def frameAxes(self): def frameAxes(self, axes): self._checkEditable() self._frameAxes = axes - + @property def frameUnits(self): return self._frameUnits @@ -950,7 +1007,7 @@ def frameValues(self, a): err = f'frameValues must have {len(self.frameAxes) + 1} dimensions.' raise ValueError(err) self._frameValues = a - + def _generateDownsampledLevels(self, resample_method): self._checkEditable() current_arrays = dict(self._zarr.arrays()) diff --git a/test/test_sink.py b/test/test_sink.py index 3bd5dd74b..28082891a 100644 --- a/test/test_sink.py +++ b/test/test_sink.py @@ -557,8 +557,8 @@ def testFrameValuesAddTile(use_add_tile_args, tmp_path): frame_shape = (300, 400, 3) expected = dict( - z=dict(values=[2, 4, 6, 8], uniform=True, units='m', stride=9, dtype='int64'), - t=dict(values=[10.0, 20.0, 30.0], uniform=False, units='ms', stride=3, dtype='float64'), + z=dict(values=[2, 4, 6, 8], uniform=True, units='meter', stride=9, dtype='int64'), + t=dict(values=[10.0, 20.0, 30.0], uniform=False, units='millisecond', stride=3, dtype='float64'), c=dict(values=['r', 'g', 'b'], uniform=True, units='channel', stride=1, dtype='str32'), ) expected_metadata = dict( @@ -635,9 +635,8 @@ def testFrameValuesAddTile(use_add_tile_args, tmp_path): if not use_add_tile_args: sink.frameValues = frame_values - compare_metadata(dict(sink.getMetadata()), expected_metadata) - # sink.write(output_file) - # written = large_image_source_zarr.open(output_file) - # compare_metadata(dict(written.getMetadata()), expected_metadata) + sink.write(output_file) + written = large_image_source_zarr.open(output_file) + compare_metadata(dict(written.getMetadata()), expected_metadata) From d02dd55b50748137830e32247cae660bb1e35c7d Mon Sep 17 00:00:00 2001 From: Anne Haley Date: Wed, 4 Sep 2024 18:57:33 +0000 Subject: [PATCH 3/5] refactor test and add small test --- test/test_sink.py | 131 ++++++++++++++++++++++++++++++++-------------- 1 file changed, 93 insertions(+), 38 deletions(-) diff --git a/test/test_sink.py b/test/test_sink.py index 28082891a..5434a8542 100644 --- a/test/test_sink.py +++ b/test/test_sink.py @@ -537,31 +537,8 @@ def testConcurrency(tmp_path): assert np.allclose(data, seed_data) -def compare_metadata(actual, expected): - assert type(actual) is type(expected) - if isinstance(actual, list): - for i, v in enumerate(actual): - compare_metadata(v, expected[i]) - elif isinstance(actual, dict): - assert len(actual.keys()) == len(expected.keys()) - for k, v in actual.items(): - compare_metadata(v, expected[k]) - else: - assert actual == expected - - -@pytest.mark.parametrize('use_add_tile_args', [True, False]) -def testFrameValuesAddTile(use_add_tile_args, tmp_path): - output_file = tmp_path / 'test.db' - sink = large_image_source_zarr.new() - - frame_shape = (300, 400, 3) - expected = dict( - z=dict(values=[2, 4, 6, 8], uniform=True, units='meter', stride=9, dtype='int64'), - t=dict(values=[10.0, 20.0, 30.0], uniform=False, units='millisecond', stride=3, dtype='float64'), - c=dict(values=['r', 'g', 'b'], uniform=True, units='channel', stride=1, dtype='str32'), - ) - expected_metadata = dict( +def get_expected_metadata(axis_spec, frame_shape): + return dict( levels=1, sizeY=frame_shape[0], sizeX=frame_shape[1], @@ -573,13 +550,13 @@ def testFrameValuesAddTile(use_add_tile_args, tmp_path): mm_x=0, mm_y=0, dtype='float64', - channels=[f'Band {c + 1}' for c in range(len(expected['c']['values']))], - channelmap={f'Band {c + 1}': c for c in range(len(expected['c']['values']))}, + channels=[f'Band {c + 1}' for c in range(len(axis_spec['c']['values']))], + channelmap={f'Band {c + 1}': c for c in range(len(axis_spec['c']['values']))}, IndexRange={ - f'Index{k.upper()}': len(v['values']) for k, v in expected.items() + f'Index{k.upper()}': len(v['values']) for k, v in axis_spec.items() }, IndexStride={ - f'Index{k.upper()}': v['stride'] for k, v in expected.items() + f'Index{k.upper()}': v['stride'] for k, v in axis_spec.items() }, **{ f'Value{k.upper()}': dict( @@ -589,27 +566,105 @@ def testFrameValuesAddTile(use_add_tile_args, tmp_path): min=min(v['values']), max=max(v['values']), datatype=v['dtype'], - ) for k, v in expected.items() + ) for k, v in axis_spec.items() } ) - sink.frameAxes = list(expected.keys()) +def compare_metadata(actual, expected): + assert type(actual) is type(expected) + if isinstance(actual, list): + for i, v in enumerate(actual): + compare_metadata(v, expected[i]) + elif isinstance(actual, dict): + assert len(actual.keys()) == len(expected.keys()) + for k, v in actual.items(): + compare_metadata(v, expected[k]) + else: + assert actual == expected + + +@pytest.mark.parametrize('use_add_tile_args', [True, False]) +def testFrameValuesSmall(use_add_tile_args, tmp_path): + output_file = tmp_path / 'test.db' + sink = large_image_source_zarr.new() + + frame_shape = (300, 400, 3) + axis_spec = dict( + c=dict(values=['r', 'g', 'b'], uniform=True, units='channel', stride=1, dtype='str32'), + ) + expected_metadata = get_expected_metadata(axis_spec, frame_shape) + + sink.frameAxes = list(axis_spec.keys()) + sink.frameUnits = { + k: v['units'] for k, v in axis_spec.items() + } + frame_values_shape = [ + *[len(v['values']) for v in axis_spec.values()], + len(axis_spec) + ] + frame_values = np.empty(frame_values_shape, dtype=object) + + frame = 0 + index = 0 + for c, c_value in enumerate(axis_spec['c']['values']): + add_tile_args = dict(c=c, axes=['c', 'y', 'x', 's']) + if use_add_tile_args: + add_tile_args.update(c_value=c_value) + else: + frame_values[c] = [c_value] + random_tile = np.random.random(frame_shape) + sink.addTile(random_tile, 0, 0, **add_tile_args) + expected_metadata['frames'].append( + dict( + Frame=frame, + Index=index, + IndexC=c, + ValueC=c_value, + Channel=f'Band {c + 1}', + ) + ) + frame += 1 + index += 1 + + if not use_add_tile_args: + sink.frameValues = frame_values + compare_metadata(dict(sink.getMetadata()), expected_metadata) + + sink.write(output_file) + written = large_image_source_zarr.open(output_file) + compare_metadata(dict(written.getMetadata()), expected_metadata) + + +@pytest.mark.parametrize('use_add_tile_args', [True, False]) +def testFrameValues(use_add_tile_args, tmp_path): + output_file = tmp_path / 'test.db' + sink = large_image_source_zarr.new() + + frame_shape = (300, 400, 3) + axis_spec = dict( + z=dict(values=[2, 4, 6, 8], uniform=True, units='meter', stride=9, dtype='int64'), + t=dict(values=[10.0, 20.0, 30.0], uniform=False, units='millisecond', stride=3, dtype='float64'), + c=dict(values=['r', 'g', 'b'], uniform=True, units='channel', stride=1, dtype='str32'), + ) + expected_metadata = get_expected_metadata(axis_spec, frame_shape) + + sink.frameAxes = list(axis_spec.keys()) sink.frameUnits = { - k: v['units'] for k, v in expected.items() + k: v['units'] for k, v in axis_spec.items() } frame_values_shape = [ - *[len(v['values']) for v in expected.values()], - len(expected) + *[len(v['values']) for v in axis_spec.values()], + len(axis_spec) ] frame_values = np.empty(frame_values_shape, dtype=object) frame = 0 index = 0 - for z, z_value in enumerate(expected['z']['values']): - for t, t_value in enumerate(expected['t']['values']): - if not expected['t']['uniform']: + for z, z_value in enumerate(axis_spec['z']['values']): + for t, t_value in enumerate(axis_spec['t']['values']): + if not axis_spec['t']['uniform']: t_value += 0.01 * z - for c, c_value in enumerate(expected['c']['values']): + for c, c_value in enumerate(axis_spec['c']['values']): add_tile_args = dict(z=z, t=t, c=c, axes=['z', 't', 'c', 'y', 'x', 's']) if use_add_tile_args: add_tile_args.update(z_value=z_value, t_value=t_value, c_value=c_value) From 0aa76b40173aba39183dae63c240965f54de95ed Mon Sep 17 00:00:00 2001 From: Anne Haley Date: Wed, 4 Sep 2024 19:16:31 +0000 Subject: [PATCH 4/5] Reformat, break up complex functions --- .../zarr/large_image_source_zarr/__init__.py | 162 ++++++++++-------- test/test_sink.py | 40 ++++- 2 files changed, 121 insertions(+), 81 deletions(-) diff --git a/sources/zarr/large_image_source_zarr/__init__.py b/sources/zarr/large_image_source_zarr/__init__.py index b717a0524..e0c178ec9 100644 --- a/sources/zarr/large_image_source_zarr/__init__.py +++ b/sources/zarr/large_image_source_zarr/__init__.py @@ -224,10 +224,10 @@ def _scanZarrArray(self, group, arr, results): tuple that is used to find the maximum size array, preferring ome arrays, then total pixels, then channels. 'is_ome' is a boolean. 'series' is a list of the found groups and arrays that match the - best criteria. 'axes', 'axes_values', 'axes_units', and 'channels' are from the best array. - 'associated' is a list of all groups and arrays that might be - associated images. These have to be culled for the actual groups - used in the series. + best criteria. 'axes', 'axes_values', 'axes_units', and 'channels' + are from the best array. 'associated' is a list of all groups and + arrays that might be associated images. These have to be culled + for the actual groups used in the series. """ attrs = group.attrs.asdict() if group is not None else {} min_version = packaging.version.Version('0.4') @@ -244,8 +244,14 @@ def _scanZarrArray(self, group, arr, results): if is_ome: axes = {axis['name']: idx for idx, axis in enumerate( attrs['multiscales'][0]['axes'])} - axes_values = {axis['name']: axis.get('values') for axis in attrs['multiscales'][0]['axes']} - axes_units = {axis['name']: axis.get('unit') for axis in attrs['multiscales'][0]['axes']} + axes_values = { + axis['name']: axis.get('values') + for axis in attrs['multiscales'][0]['axes'] + } + axes_units = { + axis['name']: axis.get('unit') + for axis in attrs['multiscales'][0]['axes'] + } if isinstance(attrs['omero'].get('channels'), list): channels = [channel['label'] for channel in attrs['omero']['channels']] if all(channel.startswith('Channel ') for channel in channels): @@ -345,6 +351,43 @@ def _getScale(self): except Exception: pass + def _readFrameValues(self, found, baseArray): + axes_values = found.get('axes_values') + axes_units = found.get('axes_units') + if isinstance(axes_values, dict): + self._frameAxes = [ + a for a, i in sorted(self._axes.items(), key=lambda x: x[1]) + if axes_values.get(a) is not None + ] + self._frameUnits = {k: axes_units.get(k) for k in self.frameAxes if k in axes_units} + frame_values_shape = [baseArray.shape[self._axes[a]] for a in self.frameAxes] + frame_values_shape.append(len(frame_values_shape)) + frame_values = np.empty(frame_values_shape, dtype=object) + for axis, values in axes_values.items(): + if axis in self.frameAxes: + slicing = [slice(None) for i in range(len(frame_values_shape))] + axis_index = self.frameAxes.index(axis) + slicing[-1] = axis_index + if isinstance(values, list): + # uniform values are written as lists + for i, value in enumerate(values): + slicing[axis_index] = i + frame_values[tuple(slicing)] = value + elif isinstance(values, dict): + # non-uniform values are written as dicts + # mapping values to index permutations + for value, frame_specs in values.items(): + if isinstance(value, str): + try: + value = float(value) if '.' in value else int(value) + except Exception: + pass + for frame_spec in frame_specs: + for a, i in frame_spec.items(): + slicing[self.frameAxes.index(a)] = i + frame_values[tuple(slicing)] = value + self._frameValues = frame_values + def _validateZarr(self): """ Validate that we can read tiles from the zarr parent group in @@ -410,40 +453,7 @@ def _validateZarr(self): self._axisCounts['xy'] = len(self._series) stride *= len(self._series) self._framecount = stride - axes_values = found.get('axes_values') - axes_units = found.get('axes_units') - if isinstance(axes_values, dict): - self._frameAxes = [ - a for a, i in sorted(self._axes.items(), key=lambda x: x[1]) - if axes_values.get(a) is not None - ] - self._frameUnits = {k: axes_units.get(k) for k in self.frameAxes if k in axes_units} - frame_values_shape = [baseArray.shape[self._axes[a]] for a in self.frameAxes] - frame_values_shape.append(len(frame_values_shape)) - frame_values = np.empty(frame_values_shape, dtype=object) - for axis, values in axes_values.items(): - if axis in self.frameAxes: - slicing = [slice(None) for i in range(len(frame_values_shape))] - axis_index = self.frameAxes.index(axis) - slicing[-1] = axis_index - if isinstance(values, list): - # uniform values are written as lists - for i, value in enumerate(values): - slicing[axis_index] = i - frame_values[tuple(slicing)] = value - elif isinstance(values, dict): - # non-uniform values are written as dicts mapping values to index permutations - for value, frame_specs in values.items(): - if isinstance(value, str): - try: - value = float(value) if '.' in value else int(value) - except: - pass - for frame_spec in frame_specs: - for a, i in frame_spec.items(): - slicing[self.frameAxes.index(a)] = i - frame_values[tuple(slicing)] = value - self._frameValues = frame_values + self._readFrameValues(found, baseArray) def _nonemptyLevelsList(self, frame=0): """ @@ -502,7 +512,7 @@ def getMetadata(self): units=self.frameUnits.get(axis) if self.frameUnits is not None else None, min=min(values), max=max(values), - datatype=np.array(values).dtype.name + datatype=np.array(values).dtype.name, ) for idx in range(self._framecount): frame = {'Frame': idx} @@ -697,7 +707,7 @@ def addTile(self, tile, x=0, y=0, mask=None, axes=None, **kwargs): elif self.frameValues.shape != frames_shape: self.frameValues = np.pad( self.frameValues, - [(0, s - self.frameValues.shape[i]) for i, s in enumerate(frames_shape)] + [(0, s - self.frameValues.shape[i]) for i, s in enumerate(frames_shape)], ) current_frame_slice = tuple(placement.get(a) for a in self.frameAxes) for i, k in enumerate(self.frameAxes): @@ -778,6 +788,39 @@ def addAssociatedImage(self, image, imageKey=None): ) self._associatedImages[imageKey] = (group, arr) + def _getAxisInternalMetadata(self, axis_name): + axis_metadata = {'name': axis_name} + if axis_name in ['x', 'y']: + axis_metadata['type'] = 'space' + axis_metadata['unit'] = 'millimeter' + elif axis_name in ['s', 'c']: + axis_metadata['type'] = 'channel' + if self.frameAxes is not None: + frame_axis_index = self.frameAxes.index(axis_name) if axis_name in self.frameAxes else None + if frame_axis_index is not None and self.frameValues is not None: + all_frame_values = self.frameValues[..., frame_axis_index] + split = np.split( + all_frame_values, + all_frame_values.shape[frame_axis_index], + axis=frame_axis_index, + ) + uniform = all(len(np.unique(a)) == 1 for a in split) + if uniform: + values = [a.flat[0] for a in split] + else: + values = {} + for index, value in np.ndenumerate(all_frame_values): + if value not in values: + values[value] = [] + values[value].append({ + axis: index[i] for i, axis in enumerate(self.frameAxes) + }) + axis_metadata['values'] = values + unit = self.frameUnits.get(axis_name) if self.frameUnits is not None else None + if unit is not None: + axis_metadata['unit'] = unit + return axis_metadata + def _writeInternalMetadata(self): self._checkEditable() with self._threadLock and self._processLock: @@ -803,37 +846,11 @@ def _writeInternalMetadata(self): } datasets.append(dataset_metadata) for a in sorted_axes: - axis_metadata = {'name': a} - if a in ['x', 'y']: - axis_metadata['type'] = 'space' - axis_metadata['unit'] = 'millimeter' - elif a in ['s', 'c']: - axis_metadata['type'] = 'channel' - elif a == 't': + if a == 't': rdefs['defaultT'] = 0 elif a == 'z': rdefs['defaultZ'] = 0 - if self.frameAxes is not None: - frame_axis_index = self.frameAxes.index(a) if a in self.frameAxes else None - if frame_axis_index is not None and self.frameValues is not None: - all_frame_values = self.frameValues[..., frame_axis_index] - split = np.split(all_frame_values, all_frame_values.shape[frame_axis_index], axis=frame_axis_index) - uniform = all(len(np.unique(a)) == 1 for a in split) - if uniform: - values = [a.flat[0] for a in split] - else: - values = {} - for index, value in np.ndenumerate(all_frame_values): - if value not in values: - values[value] = [] - values[value].append({ - axis: index[i] for i, axis in enumerate(self.frameAxes) - }) - axis_metadata['values'] = values - unit = self.frameUnits.get(a) if self.frameUnits is not None else None - if unit is not None: - axis_metadata['unit'] = unit - axes.append(axis_metadata) + axes.append(self._getAxisInternalMetadata(a)) if channel_axis is not None and len(arrays) > 0: base_array = list(arrays.values())[0] base_shape = base_array.shape @@ -986,11 +1003,12 @@ def frameUnits(self, units): self._checkEditable() if self.frameAxes is None: err = 'frameAxes must be set first with a list of frame axis names.' - raise ValueError + raise ValueError(err) if not isinstance(units, dict) or not all( k in self.frameAxes for k in units.keys() ): err = 'frameUnits must be a dictionary with keys that exist in frameAxes.' + raise ValueError(err) self._frameUnits = units @property @@ -1002,7 +1020,7 @@ def frameValues(self, a): self._checkEditable() if self.frameAxes is None: err = 'frameAxes must be set first with a list of frame axis names.' - raise ValueError + raise ValueError(err) if len(a.shape) != len(self.frameAxes) + 1: err = f'frameValues must have {len(self.frameAxes) + 1} dimensions.' raise ValueError(err) diff --git a/test/test_sink.py b/test/test_sink.py index 5434a8542..e28b4b929 100644 --- a/test/test_sink.py +++ b/test/test_sink.py @@ -567,7 +567,7 @@ def get_expected_metadata(axis_spec, frame_shape): max=max(v['values']), datatype=v['dtype'], ) for k, v in axis_spec.items() - } + }, ) def compare_metadata(actual, expected): @@ -589,9 +589,13 @@ def testFrameValuesSmall(use_add_tile_args, tmp_path): sink = large_image_source_zarr.new() frame_shape = (300, 400, 3) - axis_spec = dict( - c=dict(values=['r', 'g', 'b'], uniform=True, units='channel', stride=1, dtype='str32'), - ) + axis_spec = dict(c=dict( + values=['r', 'g', 'b'], + uniform=True, + units='channel', + stride=1, + dtype='str32', + )) expected_metadata = get_expected_metadata(axis_spec, frame_shape) sink.frameAxes = list(axis_spec.keys()) @@ -600,7 +604,7 @@ def testFrameValuesSmall(use_add_tile_args, tmp_path): } frame_values_shape = [ *[len(v['values']) for v in axis_spec.values()], - len(axis_spec) + len(axis_spec), ] frame_values = np.empty(frame_values_shape, dtype=object) @@ -621,7 +625,7 @@ def testFrameValuesSmall(use_add_tile_args, tmp_path): IndexC=c, ValueC=c_value, Channel=f'Band {c + 1}', - ) + ), ) frame += 1 index += 1 @@ -642,9 +646,27 @@ def testFrameValues(use_add_tile_args, tmp_path): frame_shape = (300, 400, 3) axis_spec = dict( - z=dict(values=[2, 4, 6, 8], uniform=True, units='meter', stride=9, dtype='int64'), - t=dict(values=[10.0, 20.0, 30.0], uniform=False, units='millisecond', stride=3, dtype='float64'), - c=dict(values=['r', 'g', 'b'], uniform=True, units='channel', stride=1, dtype='str32'), + z=dict( + values=[2, 4, 6, 8], + uniform=True, + units='meter', + stride=9, + dtype='int64', + ), + t=dict( + values=[10.0, 20.0, 30.0], + uniform=False, + units='millisecond', + stride=3, + dtype='float64', + ), + c=dict( + values=['r', 'g', 'b'], + uniform=True, + units='channel', + stride=1, + dtype='str32', + ), ) expected_metadata = get_expected_metadata(axis_spec, frame_shape) From 6c00eed5af2d7312bbca6638aa36c99b49dcb6fb Mon Sep 17 00:00:00 2001 From: Anne Haley Date: Wed, 4 Sep 2024 19:24:15 +0000 Subject: [PATCH 5/5] Fix styling --- .../zarr/large_image_source_zarr/__init__.py | 26 +++--- test/test_sink.py | 79 ++++++++++--------- 2 files changed, 53 insertions(+), 52 deletions(-) diff --git a/sources/zarr/large_image_source_zarr/__init__.py b/sources/zarr/large_image_source_zarr/__init__.py index e0c178ec9..621f1db9a 100644 --- a/sources/zarr/large_image_source_zarr/__init__.py +++ b/sources/zarr/large_image_source_zarr/__init__.py @@ -225,8 +225,8 @@ def _scanZarrArray(self, group, arr, results): arrays, then total pixels, then channels. 'is_ome' is a boolean. 'series' is a list of the found groups and arrays that match the best criteria. 'axes', 'axes_values', 'axes_units', and 'channels' - are from the best array. 'associated' is a list of all groups and - arrays that might be associated images. These have to be culled + are from the best array. 'associated' is a list of all groups and + arrays that might be associated images. These have to be culled for the actual groups used in the series. """ attrs = group.attrs.asdict() if group is not None else {} @@ -245,11 +245,11 @@ def _scanZarrArray(self, group, arr, results): axes = {axis['name']: idx for idx, axis in enumerate( attrs['multiscales'][0]['axes'])} axes_values = { - axis['name']: axis.get('values') + axis['name']: axis.get('values') for axis in attrs['multiscales'][0]['axes'] } axes_units = { - axis['name']: axis.get('unit') + axis['name']: axis.get('unit') for axis in attrs['multiscales'][0]['axes'] } if isinstance(attrs['omero'].get('channels'), list): @@ -374,7 +374,7 @@ def _readFrameValues(self, found, baseArray): slicing[axis_index] = i frame_values[tuple(slicing)] = value elif isinstance(values, dict): - # non-uniform values are written as dicts + # non-uniform values are written as dicts # mapping values to index permutations for value, frame_specs in values.items(): if isinstance(value, str): @@ -387,7 +387,7 @@ def _readFrameValues(self, found, baseArray): slicing[self.frameAxes.index(a)] = i frame_values[tuple(slicing)] = value self._frameValues = frame_values - + def _validateZarr(self): """ Validate that we can read tiles from the zarr parent group in @@ -796,13 +796,13 @@ def _getAxisInternalMetadata(self, axis_name): elif axis_name in ['s', 'c']: axis_metadata['type'] = 'channel' if self.frameAxes is not None: - frame_axis_index = self.frameAxes.index(axis_name) if axis_name in self.frameAxes else None - if frame_axis_index is not None and self.frameValues is not None: - all_frame_values = self.frameValues[..., frame_axis_index] + axis_index = self.frameAxes.index(axis_name) if axis_name in self.frameAxes else None + if axis_index is not None and self.frameValues is not None: + all_frame_values = self.frameValues[..., axis_index] split = np.split( all_frame_values, - all_frame_values.shape[frame_axis_index], - axis=frame_axis_index, + all_frame_values.shape[axis_index], + axis=axis_index, ) uniform = all(len(np.unique(a)) == 1 for a in split) if uniform: @@ -819,8 +819,8 @@ def _getAxisInternalMetadata(self, axis_name): unit = self.frameUnits.get(axis_name) if self.frameUnits is not None else None if unit is not None: axis_metadata['unit'] = unit - return axis_metadata - + return axis_metadata + def _writeInternalMetadata(self): self._checkEditable() with self._threadLock and self._processLock: diff --git a/test/test_sink.py b/test/test_sink.py index e28b4b929..e2634d1dd 100644 --- a/test/test_sink.py +++ b/test/test_sink.py @@ -570,6 +570,7 @@ def get_expected_metadata(axis_spec, frame_shape): }, ) + def compare_metadata(actual, expected): assert type(actual) is type(expected) if isinstance(actual, list): @@ -611,23 +612,23 @@ def testFrameValuesSmall(use_add_tile_args, tmp_path): frame = 0 index = 0 for c, c_value in enumerate(axis_spec['c']['values']): - add_tile_args = dict(c=c, axes=['c', 'y', 'x', 's']) - if use_add_tile_args: - add_tile_args.update(c_value=c_value) - else: - frame_values[c] = [c_value] - random_tile = np.random.random(frame_shape) - sink.addTile(random_tile, 0, 0, **add_tile_args) - expected_metadata['frames'].append( - dict( - Frame=frame, - Index=index, - IndexC=c, - ValueC=c_value, - Channel=f'Band {c + 1}', - ), - ) - frame += 1 + add_tile_args = dict(c=c, axes=['c', 'y', 'x', 's']) + if use_add_tile_args: + add_tile_args.update(c_value=c_value) + else: + frame_values[c] = [c_value] + random_tile = np.random.random(frame_shape) + sink.addTile(random_tile, 0, 0, **add_tile_args) + expected_metadata['frames'].append( + dict( + Frame=frame, + Index=index, + IndexC=c, + ValueC=c_value, + Channel=f'Band {c + 1}', + ), + ) + frame += 1 index += 1 if not use_add_tile_args: @@ -676,7 +677,7 @@ def testFrameValues(use_add_tile_args, tmp_path): } frame_values_shape = [ *[len(v['values']) for v in axis_spec.values()], - len(axis_spec) + len(axis_spec), ] frame_values = np.empty(frame_values_shape, dtype=object) @@ -687,27 +688,27 @@ def testFrameValues(use_add_tile_args, tmp_path): if not axis_spec['t']['uniform']: t_value += 0.01 * z for c, c_value in enumerate(axis_spec['c']['values']): - add_tile_args = dict(z=z, t=t, c=c, axes=['z', 't', 'c', 'y', 'x', 's']) - if use_add_tile_args: - add_tile_args.update(z_value=z_value, t_value=t_value, c_value=c_value) - else: - frame_values[z, t, c] = [z_value, t_value, c_value] - random_tile = np.random.random(frame_shape) - sink.addTile(random_tile, 0, 0, **add_tile_args) - expected_metadata['frames'].append( - dict( - Frame=frame, - Index=index, - IndexZ=z, - ValueZ=z_value, - IndexT=t, - ValueT=t_value, - IndexC=c, - ValueC=c_value, - Channel=f'Band {c + 1}', - ) - ) - frame += 1 + add_tile_args = dict(z=z, t=t, c=c, axes=['z', 't', 'c', 'y', 'x', 's']) + if use_add_tile_args: + add_tile_args.update(z_value=z_value, t_value=t_value, c_value=c_value) + else: + frame_values[z, t, c] = [z_value, t_value, c_value] + random_tile = np.random.random(frame_shape) + sink.addTile(random_tile, 0, 0, **add_tile_args) + expected_metadata['frames'].append( + dict( + Frame=frame, + Index=index, + IndexZ=z, + ValueZ=z_value, + IndexT=t, + ValueT=t_value, + IndexC=c, + ValueC=c_value, + Channel=f'Band {c + 1}', + ), + ) + frame += 1 index += 1 if not use_add_tile_args: