Skip to content

Commit

Permalink
Implement crossdecoding and patch_data_func in braindecode_utils
Browse files Browse the repository at this point in the history
  • Loading branch information
smathot committed Nov 11, 2022
1 parent 7333cf5 commit 294f39b
Show file tree
Hide file tree
Showing 2 changed files with 49 additions and 33 deletions.
2 changes: 1 addition & 1 deletion eeg_eyetracking_parser/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,6 @@
tfr_to_surface


__version__ = '0.10.0'
__version__ = '0.11.0'
logger = logging.getLogger('eeg_eyetracking_parser')
logger.info(f'eeg_eyetracking_parser {__version__}')
80 changes: 48 additions & 32 deletions eeg_eyetracking_parser/braindecode_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,9 +34,9 @@

@fnc.memoize(persistent=True)
def decode_subject(read_subject_kwargs, factors, epochs_kwargs, trigger,
epochs_query='practice == "no"', epochs=4, lesions=None,
window_size=200, window_stride=1, n_fold=4,
pretrained_clf=None):
epochs_query='practice == "no"', epochs=4, window_size=200,
window_stride=1, n_fold=4, crossdecode_factors=None,
patch_data_func=None):
"""The main entry point for decoding a subject's data.
Parameters
Expand All @@ -45,11 +45,10 @@ def decode_subject(read_subject_kwargs, factors, epochs_kwargs, trigger,
A dict with keyword arguments that are passed to eet.read_subject() to
load the data. Additional preprocessing as specified in
`preprocess_raw()` is applied afterwards.
factors: list of str
A list of factors that should be decoded. Factors should be str and
match column names in the metadata. If there is more than one factor,
each factor should have two levels.
epochs_kwargs: dict
factors: str or list of str
A factor or list of factors that should be decoded. Factors should be
str and match column names in the metadata.
epochs_kwargs: dict, optional
A dict with keyword arguments that are passed to mne.Epochs() to
extract the to-be-decoded epoch.
trigger: int
Expand All @@ -62,20 +61,27 @@ def decode_subject(read_subject_kwargs, factors, epochs_kwargs, trigger,
epochs: int, optional
The number of training epochs, i.e. the number of times that the data
is fed into the model. This should be at least 2.
lesions: list of tuple or str
A list of time windows or electrode names to be set to 0 during
testing. A separate prediction is made for each lesion. Time windows
are (start, end) tuples in sample units. Electrode names are strings.
window_size_samples: int
window_size_samples: int, optional
The length of the window to sample from the Epochs object. This should
be slightly shorter than the actual Epochs to allow for jittered
samples to be taken from the purpose of 'cropped decoding'.
window_stride_samples: int
window_stride_samples: int, optional
The number of samples to jitter around the window for the purpose of
cropped decoding.
n_fold: int
n_fold: int, optional
The total number of splits (or folds). This should be at least 2.
crossdecode_factors: str or list of str, optional
A factor or list of factors that should be decoded during tester. If
provided, the classifier is trained using the factors specified in
`factors` and tested using the factors specified in
`crossdecode_factors`. In other words, specifying this keyword allow
for crossdecoding.
patch_data_func: callable or None, optional
If provided, this should be a function that accepts a tuple of
`(raw, events, metadata)` as returned by `read_subject()` and also
returns a tuple of `(raw, events, metadata)`. This function can modify
aspects of the data before decoding is applied.
Returns
-------
DataMatrix
Expand All @@ -95,16 +101,21 @@ def decode_subject(read_subject_kwargs, factors, epochs_kwargs, trigger,
raise ValueError('n_fold should >= 2')
dataset, labels, metadata = read_decode_dataset(
read_subject_kwargs, factors, epochs_kwargs, trigger, epochs_query,
window_size=window_size, window_stride=window_stride)
window_size=window_size, window_stride=window_stride,
patch_data_func=patch_data_func)
if crossdecode_factors is not None:
cd_dataset, labels, metadata = read_decode_dataset(
read_subject_kwargs, crossdecode_factors, epochs_kwargs, trigger,
epochs_query, window_size=window_size, window_stride=window_stride,
patch_data_func=patch_data_func)
n_conditions = len(labels)
predictions = DataMatrix(length=0)
for fold in range(n_fold):
train_data, test_data = _split_dataset(
dataset, fold=fold, n_fold=n_fold)
if pretrained_clf is None:
clf = train(train_data, test_data, epochs=epochs)
else:
clf = pretrained_clf
train_data, test_data = _split_dataset(dataset, fold=fold,
n_fold=n_fold)
if crossdecode_factors is not None:
_, test_data = _split_dataset(cd_dataset, fold=fold, n_fold=n_fold)
clf = train(train_data, test_data, epochs=epochs)
# We can unbalance the data after training to save time and to make the
# cell counts match again
_unbalance_dataset(test_data)
Expand Down Expand Up @@ -142,12 +153,15 @@ def decode_subject(read_subject_kwargs, factors, epochs_kwargs, trigger,


def read_decode_dataset(read_subject_kwargs, factors, epochs_kwargs, trigger,
epochs_query, lesion=None, window_size=200,
window_stride=1):
epochs_query='practice == "no"', lesion=None,
window_size=200, window_stride=1,
patch_data_func=None):
"""Reads a dataset and converts it to a format that is suitable for
braindecode.
"""
raw, events, metadata = read_subject(**read_subject_kwargs)
if patch_data_func is not None:
raw, events, metadata = patch_data_func(raw, events, metadata)
_preprocess_raw(raw)
epochs = mne.Epochs(raw, epoch_trigger(events, trigger),
metadata=metadata, **epochs_kwargs)
Expand Down Expand Up @@ -315,24 +329,26 @@ def _preprocess_raw(raw, l_freq=4, h_freq=30, factor_new=1e-3,


def _split_epochs(epochs, metadata, factors):
"""Splits an Epochs object based on several factors, which should
"""Splits an Epochs object based on several factors, which should
correspond to columns in the metadata.
Parameters
----------
epochs: Epochs
The Epochs object to split
metadata: DataFrame
The metadata that belongs to epochs
factors: list
A list of factors by which the data should be split.
factors: str or list
A factor or list of factors by which the data should be split.
Returns
-------
tuple
A tuple in which the first element is a list of Epochs objects, and the
second element a list of str labels that define the Epochs objects.
"""
if isinstance(factors, str):
factors = [factors]
subsets = []
labels = []
for code, values in enumerate(
Expand Down Expand Up @@ -451,8 +467,8 @@ def _build_dataset(epochs, metadata, factors, window_size_samples,
The Epochs object to split
metadata: DataFrame
The metadata that belongs to epochs
factors: list
A list of factors by which the data should be split
factors: str or list
A factor or list of factors by which the data should be split.
window_size_samples: int
The length of the window to sample from the Epochs object. This should
be slightly shorter than the actual Epochs to allow for jittered
Expand Down

0 comments on commit 294f39b

Please sign in to comment.