Skip to content

Update to curation format v2 #157

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 27 additions & 16 deletions spikeinterface_gui/controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,10 @@
import spikeinterface.qualitymetrics
from spikeinterface.core.sorting_tools import spike_vector_to_indices
from spikeinterface.core.core_tools import check_json
from spikeinterface.curation import validate_curation_dict
from spikeinterface.widgets.utils import make_units_table_from_analyzer

from .curation_tools import adding_group, default_label_definitions, empty_curation_data
from .curation_tools import add_merge, default_label_definitions, empty_curation_data

spike_dtype =[('sample_index', 'int64'), ('unit_index', 'int64'),
('channel_index', 'int64'), ('segment_index', 'int64'),
Expand Down Expand Up @@ -319,7 +320,17 @@ def __init__(self, analyzer=None, backend="qt", parent=None, verbose=False, save
if curation_data is None:
self.curation_data = empty_curation_data.copy()
else:
self.curation_data = curation_data
# validate the curation data
format_version = curation_data.get("format_version", None)
# assume version 2 if not present
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No, I would not do this. This will be a mess with previous curation files.
Let's be strict and accept only dicts that do have the format_version; this is part of the spec, no?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No curation format is needed for validation!

What we can do is try 2 and then 1. Ok?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's add it to the test curation. The controller should fail if it is not provided.

if format_version is None:
raise ValueError("Curation data format version is missing and is required in the curation data.")
try:
validate_curation_dict(curation_data)
self.curation_data = curation_data
except Exception as e:
print(f"Invalid curation data. Initializing with empty curation data.\nError: {e}")
self.curation_data = empty_curation_data.copy()

self.has_default_quality_labels = False
if "label_definitions" not in self.curation_data:
Expand Down Expand Up @@ -668,7 +679,7 @@ def curation_can_be_saved(self):

def construct_final_curation(self):
    """
    Build the curation dictionary that is exported by the controller.

    The result starts with the curation format version ("2") and the
    controller's unit ids converted to a plain Python list. Every entry of
    ``self.curation_data`` is then merged on top, so a key that is also
    present in the curation data takes precedence.

    Returns
    -------
    dict
        A new dictionary. The merge is shallow: nested lists are shared
        with ``self.curation_data``.
    """
    final_curation = {
        "format_version": "2",
        "unit_ids": self.unit_ids.tolist(),
    }
    final_curation.update(self.curation_data)
    return final_curation
Expand Down Expand Up @@ -699,14 +710,14 @@ def make_manual_delete_if_possible(self, removed_unit_ids):
if not self.curation:
return

all_merged_units = sum(self.curation_data["merge_unit_groups"], [])
all_merged_units = sum([m["unit_ids"] for m in self.curation_data["merges"]], [])
for unit_id in removed_unit_ids:
if unit_id in self.curation_data["removed_units"]:
if unit_id in self.curation_data["removed"]:
continue
# TODO: check if unit is already in a merge group
if unit_id in all_merged_units:
continue
self.curation_data["removed_units"].append(unit_id)
self.curation_data["removed"].append(unit_id)
if self.verbose:
print(f"Unit {unit_id} is removed from the curation data")

Expand All @@ -718,10 +729,10 @@ def make_manual_restore(self, restore_unit_ids):
return

for unit_id in restore_unit_ids:
if unit_id in self.curation_data["removed_units"]:
if unit_id in self.curation_data["removed"]:
if self.verbose:
print(f"Unit {unit_id} is restored from the curation data")
self.curation_data["removed_units"].remove(unit_id)
self.curation_data["removed"].remove(unit_id)

def make_manual_merge_if_possible(self, merge_unit_ids):
"""
Expand All @@ -740,22 +751,22 @@ def make_manual_merge_if_possible(self, merge_unit_ids):
return False

for unit_id in merge_unit_ids:
if unit_id in self.curation_data["removed_units"]:
if unit_id in self.curation_data["removed"]:
return False
merged_groups = adding_group(self.curation_data["merge_unit_groups"], merge_unit_ids)
self.curation_data["merge_unit_groups"] = merged_groups

new_merges = add_merge(self.curation_data["merges"], merge_unit_ids)
self.curation_data["merges"] = new_merges
if self.verbose:
print(f"Merged unit group: {merge_unit_ids}")
print(f"Merged unit group: {[str(u) for u in merge_unit_ids]}")
return True

def make_manual_restore_merge(self, merge_group_indices):
    """
    Undo manual merges by dropping entries from ``self.curation_data["merges"]``.

    Does nothing when curation is disabled (``self.curation`` is falsy).

    Parameters
    ----------
    merge_group_indices : iterable of int
        Positions, in ``self.curation_data["merges"]``, of the merges to
        remove. Duplicated positions are removed only once.
    """
    if not self.curation:
        return

    merges = self.curation_data["merges"]
    # Pop from the highest position down: removing an entry shifts every
    # later entry, so ascending order would remove the wrong merges.
    for merge_index in sorted(set(merge_group_indices), reverse=True):
        restored_merge = merges.pop(merge_index)
        if self.verbose:
            print(f"Unmerged merge group {restored_merge['unit_ids']}")

def get_curation_label_definitions(self):
# give only label definition with exclusive
Expand Down
32 changes: 17 additions & 15 deletions spikeinterface_gui/curation_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,27 +11,29 @@

# Template for a fresh curation in curation format v2.
# The values are mutable lists: a plain ``.copy()`` of this dict shares them,
# so copy deeply before mutating a curation derived from this template.
empty_curation_data = {
    # Requested in review of this change: format_version="2"
    "format_version": "2",
    "manual_labels": [],
    "merges": [],
    "splits": [],
    # The controller and the views read the "removed" key (not "removes").
    "removed": [],
}

def add_merge(previous_merges, new_merge_unit_ids):
    """
    Register a new merge, fusing it with every existing merge it overlaps.

    Parameters
    ----------
    previous_merges : list of dict
        Existing merges, each a dict holding a "unit_ids" list.
    new_merge_unit_ids : sequence
        Unit ids that must be merged together.

    Returns
    -------
    list of dict
        A new list. The first entry is ``{"unit_ids": [...]}`` for the fused
        merge: the new unit ids followed by the unit ids of every previous
        merge that shares at least one unit with them, without duplicates.
        The remaining entries are copies of the previous merges that were
        not affected, in their original order.
    """
    # np.array(...).tolist() ensures numpy scalars (e.g. np.str_) are
    # rendered as native Python types.
    new_unit_ids = np.array(new_merge_unit_ids).tolist()

    fused_unit_ids = list(new_unit_ids)
    untouched_merges = []
    for previous_merge in previous_merges:
        previous_unit_ids = previous_merge["unit_ids"]
        if any(unit_id in previous_unit_ids for unit_id in new_unit_ids):
            fused_unit_ids.extend(previous_unit_ids)
        else:
            # Shallow copy so that any extra key of the merge is preserved.
            untouched_merges.append(dict(previous_merge))

    # dict.fromkeys removes duplicates while keeping first-seen order,
    # whereas a set would give an arbitrary order.
    fused_merge = {"unit_ids": list(dict.fromkeys(fused_unit_ids))}
    return [fused_merge] + untouched_merges
12 changes: 6 additions & 6 deletions spikeinterface_gui/curationview.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ def _qt_make_layout(self):
def _qt_refresh(self):
from .myqt import QT
# Merged
merged_units = self.controller.curation_data["merge_unit_groups"]
merged_units = [m["unit_ids"] for m in self.controller.curation_data["merges"]]
self.table_merge.clear()
self.table_merge.setRowCount(len(merged_units))
self.table_merge.setColumnCount(1)
Expand All @@ -115,7 +115,7 @@ def _qt_refresh(self):
self.table_merge.resizeColumnToContents(i)

## deleted
removed_units = self.controller.curation_data["removed_units"]
removed_units = self.controller.curation_data["removed"]
self.table_delete.clear()
self.table_delete.setRowCount(len(removed_units))
self.table_delete.setColumnCount(1)
Expand Down Expand Up @@ -161,7 +161,7 @@ def _qt_on_item_selection_changed_merge(self):

dtype = self.controller.unit_ids.dtype
ind = self.table_merge.selectedIndexes()[0].row()
visible_unit_ids = self.controller.curation_data["merge_unit_groups"][ind]
visible_unit_ids = [m["unit_ids"] for m in self.controller.curation_data["merges"]][ind]
visible_unit_ids = [dtype.type(unit_id) for unit_id in visible_unit_ids]
self.controller.set_visible_unit_ids(visible_unit_ids)
self.notify_unit_visibility_changed()
Expand All @@ -170,7 +170,7 @@ def _qt_on_item_selection_changed_delete(self):
if len(self.table_delete.selectedIndexes()) == 0:
return
ind = self.table_delete.selectedIndexes()[0].row()
unit_id = self.controller.curation_data["removed_units"][ind]
unit_id = self.controller.curation_data["removed"][ind]
self.controller.set_all_unit_visibility_off()
# convert to the correct type
unit_id = self.controller.unit_ids.dtype.type(unit_id)
Expand Down Expand Up @@ -332,7 +332,7 @@ def _panel_make_layout(self):
def _panel_refresh(self):
import pandas as pd
# Merged
merged_units = self.controller.curation_data["merge_unit_groups"]
merged_units = [m["unit_ids"] for m in self.controller.curation_data["merges"]]

# for visualization, we make all row entries strings
merged_units_str = []
Expand All @@ -345,7 +345,7 @@ def _panel_refresh(self):
self.table_merge.selection = []

## deleted
removed_units = self.controller.curation_data["removed_units"]
removed_units = self.controller.curation_data["removed"]
removed_units = [str(unit_id) for unit_id in removed_units]
df = pd.DataFrame({"deleted_unit_id": removed_units})
self.table_delete.value = df
Expand Down
2 changes: 1 addition & 1 deletion spikeinterface_gui/mergeview.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ def get_table_data(self, include_deleted=False):
unit_ids = list(self.controller.unit_ids)
for group_ids in self.proposed_merge_unit_groups:
if not include_deleted and self.controller.curation:
deleted_unit_ids = self.controller.curation_data["removed_units"]
deleted_unit_ids = self.controller.curation_data["removed"]
if any(unit_id in deleted_unit_ids for unit_id in group_ids):
continue

Expand Down
4 changes: 2 additions & 2 deletions spikeinterface_gui/tests/test_mainwindow_qt.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ def test_launcher(verbose=True):
if __name__ == '__main__':
if not test_folder.is_dir():
setup_module()
# win = test_mainwindow(start_app=True, verbose=True, curation=True)
win = test_mainwindow(start_app=True, verbose=True, curation=True)
# win = test_mainwindow(start_app=True, verbose=True, curation=False)

test_launcher(verbose=True)
# test_launcher(verbose=True)
5 changes: 3 additions & 2 deletions spikeinterface_gui/tests/testingtools.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,7 @@ def make_analyzer_folder(test_folder, case="small", unit_dtype="str"):
def make_curation_dict(analyzer):
unit_ids = analyzer.unit_ids.tolist()
curation_dict = {
"format_version": "2",
"unit_ids": unit_ids,
"label_definitions": {
"quality":{
Expand All @@ -153,8 +154,8 @@ def make_curation_dict(analyzer):
{'unit_id': unit_ids[2], "putative_type": ["exitatory"]},
{'unit_id': unit_ids[3], "quality": ["noise"], "putative_type": ["inhibitory"]},
],
"merge_unit_groups": [unit_ids[:3], unit_ids[3:5]],
"removed_units": unit_ids[5:8],
"merges": [{"unit_ids": unit_ids[:3]}, {"unit_ids": unit_ids[3:5]}],
"removed": unit_ids[5:8],
}
return curation_dict

Expand Down