Skip to content

Commit de0ee7e

Browse files
authored
Merge pull request #90 from HYPERNETS/flag_ops
add flag mask generating methods
2 parents 3c73ab3 + 5ae0035 commit de0ee7e

File tree

2 files changed

+108
-2
lines changed

2 files changed

+108
-2
lines changed

hypernets_processor/data_io/dataset_util.py

Lines changed: 52 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,8 @@ def create_default_array(dim_sizes, dtype, dim_names=None, fill_value=None):
5353

5454
if dim_names is not None:
5555
default_array = DataArray(empty_array, dims=dim_names)
56+
elif (dim_names is None) and (dim_sizes == []):
57+
default_array = DataArray(empty_array)
5658
else:
5759
default_array = DataArray(empty_array, dims=DEFAULT_DIM_NAMES[-len(dim_sizes):])
5860

@@ -259,6 +261,50 @@ def unpack_flags(da):
259261

260262
return ds
261263

264+
@staticmethod
265+
def get_flags_mask_or(da, flags=None):
266+
"""
267+
Returns boolean mask for set of flags, defined as logical or of flags
268+
269+
:type da: xarray.DataArray
270+
:param da: dataset
271+
272+
:type flags: list
273+
:param flags: list of flags (if unset all data flags selected)
274+
275+
:return: flag masks
276+
:rtype: numpy.ndarray
277+
"""
278+
279+
flags_ds = DatasetUtil.unpack_flags(da)
280+
281+
flags = flags if flags is not None else flags_ds.variables
282+
mask_flags = [flags_ds[flag].values for flag in flags]
283+
284+
return np.logical_or.reduce(mask_flags)
285+
286+
@staticmethod
287+
def get_flags_mask_and(da, flags=None):
288+
"""
289+
Returns boolean mask for set of flags, defined as logical and of flags
290+
291+
:type da: xarray.DataArray
292+
:param da: dataset
293+
294+
:type flags: list
295+
:param flags: list of flags (if unset all data flags selected)
296+
297+
:return: flag masks
298+
:rtype: numpy.ndarray
299+
"""
300+
301+
flags_ds = DatasetUtil.unpack_flags(da)
302+
303+
flags = flags if flags is not None else flags_ds.variables
304+
mask_flags = [flags_ds[flag].values for flag in flags]
305+
306+
return np.logical_and.reduce(mask_flags)
307+
262308
@staticmethod
263309
def set_flag(da, flag_name, error_if_set=False):
264310
"""
@@ -281,7 +327,9 @@ def set_flag(da, flag_name, error_if_set=False):
281327
flag_bit = flag_meanings.index(flag_name)
282328
flag_mask = flag_masks[flag_bit]
283329

284-
return da | flag_mask
330+
da.values = da.values | flag_mask
331+
332+
return da
285333

286334
@staticmethod
287335
def unset_flag(da, flag_name, error_if_unset=False):
@@ -305,7 +353,9 @@ def unset_flag(da, flag_name, error_if_unset=False):
305353
flag_bit = flag_meanings.index(flag_name)
306354
flag_mask = flag_masks[flag_bit]
307355

308-
return da & ~flag_mask
356+
da.values = da.values & ~flag_mask
357+
358+
return da
309359

310360
@staticmethod
311361
def get_set_flags(da):

hypernets_processor/data_io/tests/test_dataset_util.py

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -288,6 +288,62 @@ def test_unpack_flags(self):
288288
self.assertTrue((flags["flag7"].data == empty).all())
289289
self.assertTrue((flags["flag8"].data == empty).all())
290290

291+
def test_get_flags_mask_or(self):
292+
293+
ds = Dataset()
294+
meanings = ["flag1", "flag2", "flag3", "flag4", "flag5", "flag6", "flag7", "flag8"]
295+
flags_vector_variable = DatasetUtil.create_flags_variable([2,3], meanings, dim_names=["dim1", "dim2"],
296+
attributes={"standard_name": "std"})
297+
298+
ds["flags"] = flags_vector_variable
299+
ds["flags"] = DatasetUtil.set_flag(ds["flags"], "flag4")
300+
ds["flags"][0, 1] = DatasetUtil.set_flag(ds["flags"][0, 1], "flag5")
301+
ds["flags"][1, 1] = DatasetUtil.set_flag(ds["flags"][1, 1], "flag2")
302+
ds["flags"][1, 2] = DatasetUtil.set_flag(ds["flags"][1, 2], "flag7")
303+
304+
flags_mask = DatasetUtil.get_flags_mask_or(ds["flags"], flags=["flag5", "flag2", "flag7"])
305+
306+
expected_flags_mask = np.array([[False, True, False], [False, True, True]], dtype=bool)
307+
308+
np.testing.assert_array_almost_equal(flags_mask, expected_flags_mask)
309+
310+
def test_get_flags_mask_all(self):
311+
ds = Dataset()
312+
meanings = ["flag1", "flag2", "flag3", "flag4", "flag5", "flag6", "flag7", "flag8"]
313+
flags_vector_variable = DatasetUtil.create_flags_variable([2, 3], meanings, dim_names=["dim1", "dim2"],
314+
attributes={"standard_name": "std"})
315+
316+
ds["flags"] = flags_vector_variable
317+
ds["flags"] = DatasetUtil.set_flag(ds["flags"], "flag4")
318+
ds["flags"][0, 1] = DatasetUtil.set_flag(ds["flags"][0, 1], "flag5")
319+
ds["flags"][1, 1] = DatasetUtil.set_flag(ds["flags"][1, 1], "flag2")
320+
ds["flags"][1, 2] = DatasetUtil.set_flag(ds["flags"][1, 2], "flag7")
321+
322+
flags_mask = DatasetUtil.get_flags_mask_or(ds["flags"])
323+
324+
expected_flags_mask = np.array([[True, True, True], [True, True, True]], dtype=bool)
325+
326+
np.testing.assert_array_almost_equal(flags_mask, expected_flags_mask)
327+
328+
def test_get_flags_mask_and(self):
329+
330+
ds = Dataset()
331+
meanings = ["flag1", "flag2", "flag3", "flag4", "flag5", "flag6", "flag7", "flag8"]
332+
flags_vector_variable = DatasetUtil.create_flags_variable([2,3], meanings, dim_names=["dim1", "dim2"],
333+
attributes={"standard_name": "std"})
334+
335+
ds["flags"] = flags_vector_variable
336+
ds["flags"] = DatasetUtil.set_flag(ds["flags"], "flag4")
337+
ds["flags"][0, 1] = DatasetUtil.set_flag(ds["flags"][0, 1], "flag5")
338+
ds["flags"][1, 1] = DatasetUtil.set_flag(ds["flags"][1, 1], "flag2")
339+
ds["flags"][1, 1] = DatasetUtil.set_flag(ds["flags"][1, 1], "flag7")
340+
341+
flags_mask = DatasetUtil.get_flags_mask_and(ds["flags"], flags=["flag2", "flag7"])
342+
343+
expected_flags_mask = np.array([[False, False, False], [False, True, False]], dtype=bool)
344+
345+
np.testing.assert_array_almost_equal(flags_mask, expected_flags_mask)
346+
291347
def test_get_set_flags(self):
292348

293349
ds = Dataset()

0 commit comments

Comments
 (0)