diff --git a/sima/motion/dftreg.py b/sima/motion/dftreg.py index f0962e4..ba2b14c 100644 --- a/sima/motion/dftreg.py +++ b/sima/motion/dftreg.py @@ -36,6 +36,7 @@ import numpy as np from scipy.ndimage.interpolation import shift import time +import warnings from . import motion try: from pyfftw.interfaces.numpy_fft import fftn, ifftn @@ -123,11 +124,13 @@ def _estimate(self, dataset): displacements = [] for sequence in dataset: - num_planes = sequence.shape[1] - num_channels = sequence.shape[4] + num_frames, num_planes, _, _, num_channels = sequence.shape if num_channels > 1: - raise NotImplementedError("Error: only one colour channel \ - can be used for DFT motion correction. Using channel 1.") + warnings.warn("Warning: only one colour channel \ + can be used for DFT motion correction. Using channel 0.") + + # get results into a shape sima likes + frame_shifts = np.zeros([num_frames, num_planes, 2]) for plane_idx in range(num_planes): # load into memory... need to pass numpy array to dftreg. @@ -172,8 +175,7 @@ def _estimate(self, dataset): else: dy, dx = output - # get results into a shape sima likes - frame_shifts = np.zeros([len(frames), num_planes, 2]) + # add plane shift info for idx, frame in enumerate(sequence): frame_shifts[idx, plane_idx] = [dy[idx], dx[idx]] displacements.append(frame_shifts) diff --git a/sima/motion/motion.py b/sima/motion/motion.py index 4d667e9..a6d7e63 100644 --- a/sima/motion/motion.py +++ b/sima/motion/motion.py @@ -121,6 +121,11 @@ def correct(self, dataset, savedir, channel_names=None, info=None, else: mc_sequences = sequences displacements = self.estimate(sima.ImagingDataset(mc_sequences, None)) + + # enforce integer displacements + displacements = [d if issubclass(d.dtype.type, np.integer) else \ + d.round().astype(np.int64) for d in displacements] + disp_dim = displacements[0].shape[-1] max_disp = np.ceil( np.max(list(it.chain.from_iterable(d.reshape(-1, disp_dim) @@ -188,7 +193,17 @@ def _estimate(self, dataset): downsampled_dataset) displacements = [] for d_disps in downsampled_displacements: - disps = np.repeat(d_disps, 2, axis=2) # Repeat the displacements + if d_disps.ndim == 3: # whole frame displacements + if self._offset == 0: + disps = d_disps[:] + disps[..., 0] *= 2 # multiply y-shifts by 2 + displacements.append(disps) + continue + else: # duplicate displacements for all rows + disps = np.moveaxis(np.stack( + [d_disps] * dataset.frame_shape[0]), 0, 2) + else: # line-by-line displacements + disps = np.repeat(d_disps, 2, axis=2) # Repeat the displacements disps[:, :, :, 0] *= 2 # multiply y-shifts by 2 disps[:, :, 1::2, -1] += self._offset # shift even rows by offset displacements.append(disps)