diff --git a/flair/data.py b/flair/data.py
index 1ab3b29f4..8a13ae8bd 100644
--- a/flair/data.py
+++ b/flair/data.py
@@ -6,7 +6,7 @@
 from collections import Counter, defaultdict
 from operator import itemgetter
 from pathlib import Path
-from typing import Dict, Iterable, List, NamedTuple, Optional, Union, cast
+from typing import Dict, Iterable, List, NamedTuple, Optional, Tuple, Union, cast
 
 import torch
 from deprecated.sphinx import deprecated
@@ -1301,6 +1301,7 @@ def __init__(
         test: Optional[Dataset[T_co]] = None,
         name: str = "corpus",
         sample_missing_splits: Union[bool, str] = True,
+        random_seed: Optional[int] = None,
     ) -> None:
         # set name
         self.name: str = name
@@ -1314,7 +1315,7 @@
             test_portion = 0.1
             train_length = _len_dataset(train)
             test_size: int = round(train_length * test_portion)
-            test, train = randomly_split_into_two_datasets(train, test_size)
+            test, train = randomly_split_into_two_datasets(train, test_size, random_seed)
             log.warning(
                 "No test split found. Using %.0f%% (i.e. %d samples) of the train split as test data",
                 test_portion,
@@ -1326,7 +1327,7 @@
             dev_portion = 0.1
             train_length = _len_dataset(train)
             dev_size: int = round(train_length * dev_portion)
-            dev, train = randomly_split_into_two_datasets(train, dev_size)
+            dev, train = randomly_split_into_two_datasets(train, dev_size, random_seed)
             log.warning(
                 "No dev split found. Using %.0f%% (i.e. %d samples) of the train split as dev data",
                 dev_portion,
@@ -1353,18 +1354,20 @@ def test(self) -> Optional[Dataset[T_co]]:
     def downsample(
         self,
         percentage: float = 0.1,
-        downsample_train=True,
-        downsample_dev=True,
-        downsample_test=True,
+        downsample_train: bool = True,
+        downsample_dev: bool = True,
+        downsample_test: bool = True,
+        random_seed: Optional[int] = None,
     ):
+        """Reduce all datasets in corpus proportionally to the given percentage."""
         if downsample_train and self._train is not None:
-            self._train = self._downsample_to_proportion(self._train, percentage)
+            self._train = self._downsample_to_proportion(self._train, percentage, random_seed)
 
         if downsample_dev and self._dev is not None:
-            self._dev = self._downsample_to_proportion(self._dev, percentage)
+            self._dev = self._downsample_to_proportion(self._dev, percentage, random_seed)
 
         if downsample_test and self._test is not None:
-            self._test = self._downsample_to_proportion(self._test, percentage)
+            self._test = self._downsample_to_proportion(self._test, percentage, random_seed)
 
         return self
 
@@ -1461,9 +1464,9 @@ def _get_all_tokens(self) -> List[str]:
         return [t.text for t in tokens]
 
     @staticmethod
-    def _downsample_to_proportion(dataset: Dataset, proportion: float):
+    def _downsample_to_proportion(dataset: Dataset, proportion: float, random_seed: Optional[int] = None) -> Subset:
         sampled_size: int = round(_len_dataset(dataset) * proportion)
-        splits = randomly_split_into_two_datasets(dataset, sampled_size)
+        splits = randomly_split_into_two_datasets(dataset, sampled_size, random_seed=random_seed)
         return splits[0]
 
     def obtain_statistics(self, label_type: Optional[str] = None, pretty_print: bool = True) -> Union[dict, str]:
@@ -1879,11 +1882,21 @@ def iob2(tags):
     return True
 
 
-def randomly_split_into_two_datasets(dataset, length_of_first):
+def randomly_split_into_two_datasets(
+    dataset: Dataset, length_of_first: int, random_seed: Optional[int] = None
+) -> Tuple[Subset, Subset]:
+    """Shuffles a dataset and splits into two subsets.
+
+    The length of the first is specified and the remaining samples go into the second subset.
+    """
     import random
 
-    indices = list(range(len(dataset)))
-    random.shuffle(indices)
+    indices = list(range(_len_dataset(dataset)))
+    if random_seed is None:
+        random.shuffle(indices)
+    else:
+        random_generator = random.Random(random_seed)
+        random_generator.shuffle(indices)
 
     first_dataset = indices[:length_of_first]
     second_dataset = indices[length_of_first:]