diff --git a/spikeinterface_gui/waveformview.py b/spikeinterface_gui/waveformview.py index 37bdd09..4922a20 100644 --- a/spikeinterface_gui/waveformview.py +++ b/spikeinterface_gui/waveformview.py @@ -3,7 +3,7 @@ from .view_base import ViewBase - +from functools import partial _wheel_refresh_time = 0.1 @@ -19,7 +19,7 @@ class WaveformView(ViewBase): {'name': 'auto_zoom_on_unit_selection', 'type': 'bool', 'value': True}, {'name': 'show_only_selected_cluster', 'type': 'bool', 'value': True}, {'name': 'plot_limit_for_flatten', 'type': 'bool', 'value': True }, - {'name': 'fillbetween', 'type': 'bool', 'value': True }, + {'name': 'plot_std', 'type': 'bool', 'value': True }, {'name': 'show_channel_id', 'type': 'bool', 'value': False}, {'name': 'sparse_display', 'type': 'bool', 'value' : True }, ] @@ -44,7 +44,7 @@ def __init__(self, controller=None, parent=None, backend="qt"): self.delta_y = np.min(np.diff(unique_y)) else: self.delta_y = 40. # um - self.factor_y = .05 + self.factor_y = .02 self.factor_x = 1.0 espx = self.delta_x / 2.5 for chan_ind, chan_id in enumerate(self.controller.channel_ids): @@ -183,6 +183,14 @@ def _qt_initialize_plot(self): self.viewBox1.widen_narrow.connect(self._qt_widen_narrow) + shortcut_scale_waveforms_up = QT.QShortcut(self.qt_widget) + shortcut_scale_waveforms_up.setKey(QT.QKeySequence("ctrl+=")) + shortcut_scale_waveforms_up.activated.connect(partial(self._qt_gain_zoom, 1.3)) + + shortcut_scale_waveforms_down = QT.QShortcut(self.qt_widget) + shortcut_scale_waveforms_down.setKey(QT.QKeySequence("ctrl+-")) + shortcut_scale_waveforms_down.activated.connect(partial(self._qt_gain_zoom, 1/1.3)) + shortcut_overlap = QT.QShortcut(self.qt_widget) shortcut_overlap.setKey(QT.QKeySequence("ctrl+o")) shortcut_overlap.activated.connect(self.toggle_overlap) @@ -314,7 +322,7 @@ def addSpan(plot): self.plot1.addItem(curve) - if self.settings['fillbetween']: + if self.settings['plot_std']: color2 = QT.QColor(color) color2.setAlpha(self.alpha) curve1 = pg.PlotCurveItem(xvect, template_avg.T.flatten() + template_std.T.flatten(), pen=color2) @@ -387,15 +395,18 @@ def _qt_refresh_mode_geometry(self, dict_visible_units, keep_range): xvectors = self.xvect[common_channel_indexes, :] * self.factor_x xvects = self.get_xvectors_not_overlap(xvectors, len(visible_unit_ids)) + if keep_range is False: + self.factor_y = 0.02 for (xvect, unit_index, unit_id) in zip(xvects, visible_unit_indices, visible_unit_ids): template_avg = self.controller.templates_average[unit_index, :, :][:, common_channel_indexes] - + template_std = self.controller.templates_std[unit_index, :, :][:, common_channel_indexes] + ypos = self.contact_location[common_channel_indexes,1] wf = template_avg wf = wf * self.factor_y * self.delta_y + ypos[None, :] - + connect = np.ones(wf.shape, dtype='bool') connect[0, :] = 0 connect[-1, :] = 0 @@ -404,7 +415,22 @@ def _qt_refresh_mode_geometry(self, dict_visible_units, keep_range): color = self.get_unit_color(unit_id) curve = pg.PlotCurveItem(xvect.flatten(), wf.T.flatten(), pen=pg.mkPen(color, width=2), connect=connect.T.flatten()) + + if self.settings['plot_std'] and (template_std is not None): + + wf_std_p = wf + template_std * self.factor_y * self.delta_y + wf_std_m = wf - template_std * self.factor_y * self.delta_y + + curve_p = pg.PlotCurveItem(xvect.flatten(), wf_std_p.T.flatten(), connect=connect.T.flatten()) + curve_m = pg.PlotCurveItem(xvect.flatten(), wf_std_m.T.flatten(), connect=connect.T.flatten()) + + color2 = QT.QColor(color) + color2.setAlpha(80) + fill = pg.FillBetweenItem(curve1=curve_m, curve2=curve_p, brush=color2) + self.plot1.addItem(fill) + self.plot1.addItem(curve) + if self.settings['show_channel_id']: for chan_ind in common_channel_indexes: @@ -414,16 +440,16 @@ def _qt_refresh_mode_geometry(self, dict_visible_units, keep_range): itemtxt.setFont(QT.QFont('', pointSize=12)) self.plot1.addItem(itemtxt) itemtxt.setPos(x, y) - + if self._x_range is None or not keep_range: - x_margin =50 - y_margin =150 + x_margin = 15 + y_margin = 20 self._x_range = np.min(xvects) - x_margin , np.max(xvects) + x_margin - visible_mask = self.controller.get_units_visibility_mask() - visible_pos = self.controller.unit_positions[visible_mask, :] - self._y1_range = np.min(visible_pos[:,1]) - y_margin , np.max(visible_pos[:,1]) + y_margin - + + channel_positions_y = self.contact_location[common_channel_indexes,1] + self._y1_range = np.min(channel_positions_y) - y_margin , np.max(channel_positions_y) + y_margin + self.plot1.setXRange(*self._x_range, padding = 0.0) self.plot1.setYRange(*self._y1_range, padding = 0.0) @@ -582,7 +608,7 @@ def _panel_refresh(self, keep_range=False): # zoom factor is reset if self.settings["auto_zoom_on_unit_selection"]: self.factor_x = 1.0 - self.factor_y = .05 + self.factor_y = .02 self._panel_refresh_mode_geometry(dict_visible_units, keep_range=keep_range) elif self.mode=='flatten': self._panel_refresh_mode_flatten(dict_visible_units, keep_range=keep_range)