Skip to content

Implement splitting functionality (ScatterViews, CurationView, CorrelogramsView) #159

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 6 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
201 changes: 161 additions & 40 deletions spikeinterface_gui/basescatterview.py

Large diffs are not rendered by default.

129 changes: 110 additions & 19 deletions spikeinterface_gui/controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,10 @@
import spikeinterface.qualitymetrics
from spikeinterface.core.sorting_tools import spike_vector_to_indices
from spikeinterface.core.core_tools import check_json
from spikeinterface.curation import validate_curation_dict
from spikeinterface.widgets.utils import make_units_table_from_analyzer

from .curation_tools import adding_group, default_label_definitions, empty_curation_data
from .curation_tools import add_merge, default_label_definitions, empty_curation_data

spike_dtype =[('sample_index', 'int64'), ('unit_index', 'int64'),
('channel_index', 'int64'), ('segment_index', 'int64'),
Expand Down Expand Up @@ -260,13 +261,14 @@ def __init__(self, analyzer=None, backend="qt", parent=None, verbose=False, save

spike_vector2 = self.analyzer.sorting.to_spike_vector(concatenated=False)
# this is dict of list because per segment spike_indices[segment_index][unit_id]
spike_indices_abs = spike_vector_to_indices(spike_vector2, unit_ids, absolute_index=True)
spike_indices = spike_vector_to_indices(spike_vector2, unit_ids)
# this is flatten
spike_per_seg = [s.size for s in spike_vector2]
# dict[unit_id] -> all indices for this unit across segments
self._spike_index_by_units = {}
# dict[seg_index][unit_id] -> all indices for this unit for one segment
self._spike_index_by_segment_and_units = spike_indices
self._spike_index_by_segment_and_units = spike_indices_abs
for unit_id in unit_ids:
inds = []
for seg_ind in range(num_seg):
Expand Down Expand Up @@ -319,7 +321,23 @@ def __init__(self, analyzer=None, backend="qt", parent=None, verbose=False, save
if curation_data is None:
self.curation_data = empty_curation_data.copy()
else:
self.curation_data = curation_data
# validate the curation data
format_version = curation_data.get("format_version", None)
# assume version 2 if not present
if format_version is None:
raise ValueError("Curation data format version is missing and is required in the curation data.")
try:
validate_curation_dict(curation_data)
self.curation_data = curation_data
except Exception as e:
print(f"Invalid curation data. Initializing with empty curation data.\nError: {e}")
self.curation_data = empty_curation_data.copy()
if curation_data.get("merges") is None:
curation_data["merges"] = []
if curation_data.get("splits") is None:
curation_data["splits"] = []
if curation_data.get("removed") is None:
curation_data["removed"] = []

self.has_default_quality_labels = False
if "label_definitions" not in self.curation_data:
Expand All @@ -337,6 +355,8 @@ def __init__(self, analyzer=None, backend="qt", parent=None, verbose=False, save
print('Curation quality labels are the default ones')
self.has_default_quality_labels = True

# this is used to store the active split unit
self.active_split = None

def check_is_view_possible(self, view_name):
from .viewlist import possible_class_views
Expand Down Expand Up @@ -442,6 +462,13 @@ def set_visible_unit_ids(self, visible_unit_ids):
if len(visible_unit_ids) > lim:
visible_unit_ids = visible_unit_ids[:lim]
self._visible_unit_ids = list(visible_unit_ids)
self.active_split = None
if len(visible_unit_ids) == 1 and self.curation:
# check if unit is split
for split in self.curation_data['splits']:
if visible_unit_ids[0] == split['unit_id']:
self.active_split = split
break

def get_visible_unit_ids(self):
"""Get list of visible unit_ids"""
Expand Down Expand Up @@ -506,10 +533,21 @@ def get_indices_spike_visible(self):
return self._spike_visible_indices

def get_indices_spike_selected(self):
if self.active_split is not None:
# select the splitted spikes in the active split
split_unit_id = self.active_split['unit_id']
spike_inds = self.get_spike_indices(split_unit_id, seg_index=None)
split_indices = self.active_split['indices']
self._spike_selected_indices = np.array(spike_inds[split_indices], dtype='int64')
return self._spike_selected_indices

def set_indices_spike_selected(self, inds):
self._spike_selected_indices = np.array(inds)
if len(self._spike_selected_indices) == 1:
# set time info
segment_index = self.spikes['segment_index'][self._spike_selected_indices[0]]
sample_index = self.spikes['sample_index'][self._spike_selected_indices[0]]
self.set_time(time=sample_index / self.sampling_frequency, segment_index=segment_index)

def get_spike_indices(self, unit_id, seg_index=None):
if seg_index is None:
Expand Down Expand Up @@ -668,7 +706,7 @@ def curation_can_be_saved(self):

def construct_final_curation(self):
d = dict()
d["format_version"] = "1"
d["format_version"] = "2"
d["unit_ids"] = self.unit_ids.tolist()
d.update(self.curation_data.copy())
return d
Expand Down Expand Up @@ -699,14 +737,14 @@ def make_manual_delete_if_possible(self, removed_unit_ids):
if not self.curation:
return

all_merged_units = sum(self.curation_data["merge_unit_groups"], [])
all_merged_units = sum([m["unit_ids"] for m in self.curation_data["merges"]], [])
for unit_id in removed_unit_ids:
if unit_id in self.curation_data["removed_units"]:
if unit_id in self.curation_data["removed"]:
continue
# TODO: check if unit is already in a merge group
if unit_id in all_merged_units:
continue
self.curation_data["removed_units"].append(unit_id)
self.curation_data["removed"].append(unit_id)
if self.verbose:
print(f"Unit {unit_id} is removed from the curation data")

Expand All @@ -718,10 +756,10 @@ def make_manual_restore(self, restore_unit_ids):
return

for unit_id in restore_unit_ids:
if unit_id in self.curation_data["removed_units"]:
if unit_id in self.curation_data["removed"]:
if self.verbose:
print(f"Unit {unit_id} is restored from the curation data")
self.curation_data["removed_units"].remove(unit_id)
self.curation_data["removed"].remove(unit_id)

def make_manual_merge_if_possible(self, merge_unit_ids):
"""
Expand All @@ -740,22 +778,75 @@ def make_manual_merge_if_possible(self, merge_unit_ids):
return False

for unit_id in merge_unit_ids:
if unit_id in self.curation_data["removed_units"]:
if unit_id in self.curation_data["removed"]:
return False

new_merges = add_merge(self.curation_data["merges"], merge_unit_ids)
self.curation_data["merges"] = new_merges
if self.verbose:
print(f"Merged unit group: {[str(u) for u in merge_unit_ids]}")
return True

def make_manual_split_if_possible(self, unit_id, indices):
"""
Check if the a unit_id can be split into a new split in the curation_data.

If unit_id is already in the removed list then the split is skipped.
If unit_id is already in some other split then the split is skipped.
"""
if not self.curation:
return False

if unit_id in self.curation_data["removed"]:
return False

# check if unit_id is already in a split
for split in self.curation_data["splits"]:
if split["unit_id"] == unit_id:
return False
merged_groups = adding_group(self.curation_data["merge_unit_groups"], merge_unit_ids)
self.curation_data["merge_unit_groups"] = merged_groups

new_split = {
"unit_id": unit_id,
"mode": "indices",
"indices": indices
}
self.curation_data["splits"].append(new_split)
if self.verbose:
print(f"Merged unit group: {merge_unit_ids}")
print(f"Split unit {unit_id} with {len(indices)} spikes")
return True

def make_manual_restore_merge(self, merge_group_indices):
def make_manual_restore_merge(self, merge_indices):
if not self.curation:
return
for merge_index in merge_indices:
if self.verbose:
print(f"Unmerged {self.curation_data['merges'][merge_index]['unit_ids']}")
self.curation_data["merges"].pop(merge_index)

def make_manual_restore_split(self, split_indices):
if not self.curation:
return
merge_groups_to_remove = [self.curation_data["merge_unit_groups"][merge_group_index] for merge_group_index in merge_group_indices]
for merge_group in merge_groups_to_remove:
for split_index in split_indices:
if self.verbose:
print(f"Unmerged merge group {merge_group}")
self.curation_data["merge_unit_groups"].remove(merge_group)
print(f"Unsplitting {self.curation_data['splits'][split_index]['unit_id']}")
self.curation_data["splits"].pop(split_index)

def set_active_split_unit(self, unit_id):
"""
Set the active split unit_id.
This is used to set the label for the split unit.
"""
if not self.curation:
return
if unit_id is None:
self.active_split = None
else:
if unit_id in self.curation_data["removed"]:
print(f"Unit {unit_id} is removed, cannot set as active split unit")
return
active_split = [s for s in self.curation_data["splits"] if s["unit_id"] == unit_id]
if len(active_split) == 1:
self.active_split = active_split[0]

def get_curation_label_definitions(self):
# give only label definition with exclusive
Expand Down
Loading