diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index a8465d89d..82a721f47 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -19,7 +19,7 @@ repos: tests/resources/.*\.(tsv|txt|testa|testb|train|conllu|json) )$ - repo: https://github.com/pycqa/isort - rev: 5.10.1 + rev: 5.12.0 hooks: - id: isort name: isort (python) diff --git a/flair/datasets/__init__.py b/flair/datasets/__init__.py index 0afedeb1d..5263d71e1 100755 --- a/flair/datasets/__init__.py +++ b/flair/datasets/__init__.py @@ -168,6 +168,7 @@ CONLL_03_GERMAN, CONLL_03_SPANISH, CONLL_2000, + FEWNERD, KEYPHRASE_INSPEC, KEYPHRASE_SEMEVAL2010, KEYPHRASE_SEMEVAL2017, @@ -438,6 +439,7 @@ "CONLL_03_GERMAN", "CONLL_03_SPANISH", "CONLL_2000", + "FEWNERD", "KEYPHRASE_INSPEC", "KEYPHRASE_SEMEVAL2010", "KEYPHRASE_SEMEVAL2017", diff --git a/flair/datasets/sequence_labeling.py b/flair/datasets/sequence_labeling.py index 527b038fb..fd5eeea39 100644 --- a/flair/datasets/sequence_labeling.py +++ b/flair/datasets/sequence_labeling.py @@ -1092,6 +1092,79 @@ def __init__( ) +class FEWNERD(ColumnCorpus): + def __init__( + self, + setting: str = "supervised", + **corpusargs, + ): + assert setting in ["supervised", "inter", "intra"] + + base_path = flair.cache_root / "datasets" + self.dataset_name = self.__class__.__name__.lower() + self.data_folder = base_path / self.dataset_name / setting + self.bio_format_data = base_path / self.dataset_name / setting / "bio_format" + + if not self.data_folder.exists(): + self._download(setting=setting) + + if not self.bio_format_data.exists(): + self._generate_splits(setting) + + super(FEWNERD, self).__init__( + self.bio_format_data, + column_format={0: "text", 1: "ner"}, + **corpusargs, + ) + + def _download(self, setting): + _URLs = { + "supervised": "https://cloud.tsinghua.edu.cn/f/09265750ae6340429827/?dl=1", + "intra": "https://cloud.tsinghua.edu.cn/f/a0d3efdebddd4412b07c/?dl=1", + "inter": "https://cloud.tsinghua.edu.cn/f/165693d5e68b43558f9b/?dl=1", + } + + log.info(f"FewNERD ({setting}) dataset not found, downloading.") + dl_path = _URLs[setting] + dl_dir = cached_path(dl_path, Path("datasets") / self.dataset_name / setting) + + if setting not in os.listdir(self.data_folder): + import zipfile + + from tqdm import tqdm + + log.info("FewNERD dataset has not been extracted yet, extracting it now. This might take a while.") + with zipfile.ZipFile(dl_dir, "r") as zip_ref: + for f in tqdm(zip_ref.namelist()): + if f.endswith("/"): + os.makedirs(self.data_folder / f) + else: + zip_ref.extract(f, path=self.data_folder) + + def _generate_splits(self, setting): + log.info( + f"FewNERD splits for {setting} have not been parsed into BIO format, parsing it now. This might take a while." + ) + os.mkdir(self.bio_format_data) + for split in os.listdir(self.data_folder / setting): + with open(self.data_folder / setting / split, "r") as source: + with open(self.bio_format_data / split, "w") as target: + previous_tag = None + for line in source: + if line == "" or line == "\n": + target.write("\n") + else: + token, tag = line.split("\t") + tag = tag.replace("\n", "") + if tag == "O": + target.write(token + "\t" + tag + "\n") + elif previous_tag != tag and tag != "O": + target.write(token + "\t" + "B-" + tag + "\n") + elif previous_tag == tag and tag != "O": + target.write(token + "\t" + "I-" + tag + "\n") + previous_tag = tag + + class BIOSCOPE(ColumnCorpus): def __init__( self,