diff --git a/spikeinterface_gui/controller.py b/spikeinterface_gui/controller.py index ad653e6..4f4135b 100644 --- a/spikeinterface_gui/controller.py +++ b/spikeinterface_gui/controller.py @@ -12,9 +12,10 @@ import spikeinterface.qualitymetrics from spikeinterface.core.sorting_tools import spike_vector_to_indices from spikeinterface.core.core_tools import check_json +from spikeinterface.curation import validate_curation_dict from spikeinterface.widgets.utils import make_units_table_from_analyzer -from .curation_tools import adding_group, default_label_definitions, empty_curation_data +from .curation_tools import add_merge, default_label_definitions, empty_curation_data spike_dtype =[('sample_index', 'int64'), ('unit_index', 'int64'), ('channel_index', 'int64'), ('segment_index', 'int64'), @@ -319,7 +320,17 @@ def __init__(self, analyzer=None, backend="qt", parent=None, verbose=False, save if curation_data is None: self.curation_data = empty_curation_data.copy() else: - self.curation_data = curation_data + # validate the curation data + format_version = curation_data.get("format_version", None) + # assume version 2 if not present + if format_version is None: + raise ValueError("Curation data format version is missing and is required in the curation data.") + try: + validate_curation_dict(curation_data) + self.curation_data = curation_data + except Exception as e: + print(f"Invalid curation data. Initializing with empty curation data.\nError: {e}") + self.curation_data = empty_curation_data.copy() self.has_default_quality_labels = False if "label_definitions" not in self.curation_data: @@ -668,7 +679,7 @@ def curation_can_be_saved(self): def construct_final_curation(self): d = dict() - d["format_version"] = "1" + d["format_version"] = "2" d["unit_ids"] = self.unit_ids.tolist() d.update(self.curation_data.copy()) return d @@ -699,14 +710,14 @@ def make_manual_delete_if_possible(self, removed_unit_ids): if not self.curation: return - all_merged_units = sum(self.curation_data["merge_unit_groups"], []) + all_merged_units = sum([m["unit_ids"] for m in self.curation_data["merges"]], []) for unit_id in removed_unit_ids: - if unit_id in self.curation_data["removed_units"]: + if unit_id in self.curation_data["removed"]: continue # TODO: check if unit is already in a merge group if unit_id in all_merged_units: continue - self.curation_data["removed_units"].append(unit_id) + self.curation_data["removed"].append(unit_id) if self.verbose: print(f"Unit {unit_id} is removed from the curation data") @@ -718,10 +729,10 @@ def make_manual_restore(self, restore_unit_ids): return for unit_id in restore_unit_ids: - if unit_id in self.curation_data["removed_units"]: + if unit_id in self.curation_data["removed"]: if self.verbose: print(f"Unit {unit_id} is restored from the curation data") - self.curation_data["removed_units"].remove(unit_id) + self.curation_data["removed"].remove(unit_id) def make_manual_merge_if_possible(self, merge_unit_ids): """ @@ -740,22 +751,22 @@ def make_manual_merge_if_possible(self, merge_unit_ids): return False for unit_id in merge_unit_ids: - if unit_id in self.curation_data["removed_units"]: + if unit_id in self.curation_data["removed"]: return False - merged_groups = adding_group(self.curation_data["merge_unit_groups"], merge_unit_ids) - self.curation_data["merge_unit_groups"] = merged_groups + + new_merges = add_merge(self.curation_data["merges"], merge_unit_ids) + self.curation_data["merges"] = new_merges if self.verbose: - print(f"Merged unit group: {merge_unit_ids}") + print(f"Merged unit group: {[str(u) for u in merge_unit_ids]}") return True def make_manual_restore_merge(self, merge_group_indices): if not self.curation: return - merge_groups_to_remove = [self.curation_data["merge_unit_groups"][merge_group_index] for merge_group_index in merge_group_indices] - for merge_group in merge_groups_to_remove: + for merge_index in merge_group_indices: if self.verbose: - print(f"Unmerged merge group {merge_group}") - self.curation_data["merge_unit_groups"].remove(merge_group) + print(f"Unmerged merge group {self.curation_data['merge_unit_groups'][merge_index]['unit_ids']}") + self.curation_data["merges"].pop(merge_index) def get_curation_label_definitions(self): # give only label definition with exclusive diff --git a/spikeinterface_gui/curation_tools.py b/spikeinterface_gui/curation_tools.py index 2d92b41..d96964b 100644 --- a/spikeinterface_gui/curation_tools.py +++ b/spikeinterface_gui/curation_tools.py @@ -11,27 +11,29 @@ empty_curation_data = { "manual_labels": [], - "merge_unit_groups": [], - "removed_units": [] + "merges": [], + "splits": [], + "removes": [] } -def adding_group(previous_groups, new_group): +def add_merge(previous_merges, new_merge_unit_ids): # this is to ensure that np.str_ types are rendered as str - to_merge = [np.array(new_group).tolist()] + to_merge = [np.array(new_merge_unit_ids).tolist()] unchanged = [] - for c_prev in previous_groups: + for c_prev in previous_merges: is_unaffected = True - - for c_new in new_group: - if c_new in c_prev: + c_prev_unit_ids = c_prev["unit_ids"] + for c_new in new_merge_unit_ids: + if c_new in c_prev_unit_ids: is_unaffected = False - to_merge.append(c_prev) + to_merge.append(c_prev_unit_ids) break if is_unaffected: - unchanged.append(c_prev) - new_merge_group = [sum(to_merge, [])] - new_merge_group.extend(unchanged) - # Ensure the unicity - new_merge_group = [list(set(gp)) for gp in new_merge_group] - return new_merge_group + unchanged.append(c_prev_unit_ids) + + new_merge_units = [sum(to_merge, [])] + new_merge_units.extend(unchanged) + # Ensure the uniqueness + new_merges = [{"unit_ids": list(set(gp))} for gp in new_merge_units] + return new_merges diff --git a/spikeinterface_gui/curationview.py b/spikeinterface_gui/curationview.py index b6ed141..6fb2314 100644 --- a/spikeinterface_gui/curationview.py +++ b/spikeinterface_gui/curationview.py @@ -101,7 +101,7 @@ def _qt_make_layout(self): def _qt_refresh(self): from .myqt import QT # Merged - merged_units = self.controller.curation_data["merge_unit_groups"] + merged_units = [m["unit_ids"] for m in self.controller.curation_data["merges"]] self.table_merge.clear() self.table_merge.setRowCount(len(merged_units)) self.table_merge.setColumnCount(1) @@ -115,7 +115,7 @@ def _qt_refresh(self): self.table_merge.resizeColumnToContents(i) ## deleted - removed_units = self.controller.curation_data["removed_units"] + removed_units = self.controller.curation_data["removed"] self.table_delete.clear() self.table_delete.setRowCount(len(removed_units)) self.table_delete.setColumnCount(1) @@ -161,7 +161,7 @@ def _qt_on_item_selection_changed_merge(self): dtype = self.controller.unit_ids.dtype ind = self.table_merge.selectedIndexes()[0].row() - visible_unit_ids = self.controller.curation_data["merge_unit_groups"][ind] + visible_unit_ids = [m["unit_ids"] for m in self.controller.curation_data["merges"]][ind] visible_unit_ids = [dtype.type(unit_id) for unit_id in visible_unit_ids] self.controller.set_visible_unit_ids(visible_unit_ids) self.notify_unit_visibility_changed() @@ -170,7 +170,7 @@ def _qt_on_item_selection_changed_delete(self): if len(self.table_delete.selectedIndexes()) == 0: return ind = self.table_delete.selectedIndexes()[0].row() - unit_id = self.controller.curation_data["removed_units"][ind] + unit_id = self.controller.curation_data["removed"][ind] self.controller.set_all_unit_visibility_off() # convert to the correct type unit_id = self.controller.unit_ids.dtype.type(unit_id) @@ -332,7 +332,7 @@ def _panel_make_layout(self): def _panel_refresh(self): import pandas as pd # Merged - merged_units = self.controller.curation_data["merge_unit_groups"] + merged_units = [m["unit_ids"] for m in self.controller.curation_data["merges"]] # for visualization, we make all row entries strings merged_units_str = [] @@ -345,7 +345,7 @@ def _panel_refresh(self): self.table_merge.selection = [] ## deleted - removed_units = self.controller.curation_data["removed_units"] + removed_units = self.controller.curation_data["removed"] removed_units = [str(unit_id) for unit_id in removed_units] df = pd.DataFrame({"deleted_unit_id": removed_units}) self.table_delete.value = df diff --git a/spikeinterface_gui/mergeview.py b/spikeinterface_gui/mergeview.py index 8f5fa76..6578ccd 100644 --- a/spikeinterface_gui/mergeview.py +++ b/spikeinterface_gui/mergeview.py @@ -80,7 +80,7 @@ def get_table_data(self, include_deleted=False): unit_ids = list(self.controller.unit_ids) for group_ids in self.proposed_merge_unit_groups: if not include_deleted and self.controller.curation: - deleted_unit_ids = self.controller.curation_data["removed_units"] + deleted_unit_ids = self.controller.curation_data["removed"] if any(unit_id in deleted_unit_ids for unit_id in group_ids): continue diff --git a/spikeinterface_gui/tests/test_mainwindow_qt.py b/spikeinterface_gui/tests/test_mainwindow_qt.py index 14e65d2..7d2611d 100644 --- a/spikeinterface_gui/tests/test_mainwindow_qt.py +++ b/spikeinterface_gui/tests/test_mainwindow_qt.py @@ -110,7 +110,7 @@ def test_launcher(verbose=True): if __name__ == '__main__': if not test_folder.is_dir(): setup_module() - # win = test_mainwindow(start_app=True, verbose=True, curation=True) + win = test_mainwindow(start_app=True, verbose=True, curation=True) # win = test_mainwindow(start_app=True, verbose=True, curation=False) - test_launcher(verbose=True) + # test_launcher(verbose=True) diff --git a/spikeinterface_gui/tests/testingtools.py b/spikeinterface_gui/tests/testingtools.py index 4f1a7a0..92e1218 100644 --- a/spikeinterface_gui/tests/testingtools.py +++ b/spikeinterface_gui/tests/testingtools.py @@ -137,6 +137,7 @@ def make_analyzer_folder(test_folder, case="small", unit_dtype="str"): def make_curation_dict(analyzer): unit_ids = analyzer.unit_ids.tolist() curation_dict = { + "format_version": "2", "unit_ids": unit_ids, "label_definitions": { "quality":{ @@ -153,8 +154,8 @@ def make_curation_dict(analyzer): {'unit_id': unit_ids[2], "putative_type": ["exitatory"]}, {'unit_id': unit_ids[3], "quality": ["noise"], "putative_type": ["inhibitory"]}, ], - "merge_unit_groups": [unit_ids[:3], unit_ids[3:5]], - "removed_units": unit_ids[5:8], + "merges": [{"unit_ids": unit_ids[:3]}, {"unit_ids": unit_ids[3:5]}], + "removed": unit_ids[5:8], } return curation_dict