Merge pull request #2645 from flairNLP/refactor_annotations
Major refactoring of internal label logic
alanakbik authored Feb 25, 2022
2 parents 0490121 + 6882cb5 commit 016cd52
Showing 32 changed files with 1,696 additions and 1,692 deletions.
628 changes: 350 additions & 278 deletions flair/data.py

Large diffs are not rendered by default.
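Although the flair/data.py diff itself is not rendered, the updated call sites in the files below give a picture of the refactored label API. The following is a minimal, hedged sketch of that usage, assuming the methods behave as those call sites suggest (sentence slicing yields labelable spans, and get_label replaces get_tag); the example sentence and label values are illustrative only.

```python
# Minimal sketch of the refactored label API, inferred from the updated call
# sites in this diff (flair/data.py itself is not rendered above).
from flair.data import Sentence

sentence = Sentence("George Washington went to Washington .")

# Spans are obtained by slicing the sentence and labeled directly via add_label(),
# replacing the old sentence.add_complex_label(SpanLabel(...)) pattern.
span = sentence[0:2]                       # "George Washington"
span.add_label("ner", value="PER", score=1.0)

# Tokens expose get_label() (with an optional default) instead of get_tag().
token = sentence.tokens[2]                 # "went"
print(token.get_label("ner", "O").value)   # -> "O"
```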

8 changes: 4 additions & 4 deletions flair/datasets/base.py
@@ -2,12 +2,12 @@
import os
from abc import abstractmethod
from pathlib import Path
from typing import Callable, Generic, List, Union
from typing import Generic, List, Union

import torch.utils.data.dataloader
from torch.utils.data.dataset import ConcatDataset, Subset

from flair.data import DT, FlairDataset, Sentence, Token, Tokenizer
from flair.data import DT, FlairDataset, Sentence, Tokenizer
from flair.tokenization import SegtokTokenizer, SpaceTokenizer

log = logging.getLogger("flair")
@@ -105,7 +105,7 @@ class StringDataset(FlairDataset):
def __init__(
self,
texts: Union[str, List[str]],
use_tokenizer: Union[bool, Callable[[str], List[Token]], Tokenizer] = SpaceTokenizer(),
use_tokenizer: Union[bool, Tokenizer] = SpaceTokenizer(),
):
"""
Instantiate StringDataset
@@ -225,7 +225,7 @@ def _parse_document_to_sentence(
self,
text: str,
labels: List[str],
tokenizer: Union[Callable[[str], List[Token]], Tokenizer],
tokenizer: Union[bool, Tokenizer],
):
if self.max_chars_per_doc > 0:
text = text[: self.max_chars_per_doc]
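The use_tokenizer parameter above no longer accepts a bare Callable[[str], List[Token]]; only a bool or a Tokenizer instance is typed. A small hedged sketch of the narrowed signature in use (the example text is illustrative):

```python
# Sketch of the narrowed Union[bool, Tokenizer] signature: a Tokenizer
# instance (or a plain bool) is passed instead of a bare callable.
from flair.datasets.base import StringDataset
from flair.tokenization import SegtokTokenizer

dataset = StringDataset(
    texts=["Berlin is the capital of Germany."],
    use_tokenizer=SegtokTokenizer(),  # previously a Callable[[str], List[Token]] was also allowed
)
print(dataset[0])  # the dataset yields tokenized Sentence objects
```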
35 changes: 16 additions & 19 deletions flair/datasets/document_classification.py
@@ -3,15 +3,14 @@
import logging
import os
from pathlib import Path
from typing import Callable, Dict, List, Union
from typing import Dict, List, Union

import flair
from flair.data import (
Corpus,
DataPair,
FlairDataset,
Sentence,
Token,
Tokenizer,
_iter_dataset,
)
@@ -37,7 +36,7 @@ def __init__(
truncate_to_max_tokens: int = -1,
truncate_to_max_chars: int = -1,
filter_if_longer_than: int = -1,
tokenizer: Union[bool, Callable[[str], List[Token]], Tokenizer] = SegtokTokenizer(),
tokenizer: Union[bool, Tokenizer] = SegtokTokenizer(),
memory_mode: str = "partial",
label_name_map: Dict[str, str] = None,
skip_labels: List[str] = None,
@@ -140,7 +139,7 @@ def __init__(
truncate_to_max_tokens=-1,
truncate_to_max_chars=-1,
filter_if_longer_than: int = -1,
tokenizer: Union[bool, Callable[[str], List[Token]], Tokenizer] = SegtokTokenizer(),
tokenizer: Union[bool, Tokenizer] = SegtokTokenizer(),
memory_mode: str = "partial",
label_name_map: Dict[str, str] = None,
skip_labels: List[str] = None,
@@ -253,9 +252,7 @@ def __init__(
position = f.tell()
line = f.readline()

def _parse_line_to_sentence(
self, line: str, label_prefix: str, tokenizer: Union[bool, Callable[[str], List[Token]], Tokenizer]
):
def _parse_line_to_sentence(self, line: str, label_prefix: str, tokenizer: Union[bool, Tokenizer]):
words = line.split()

labels = []
@@ -1106,7 +1103,7 @@ class SENTEVAL_CR(ClassificationCorpus):

def __init__(
self,
tokenizer: Union[bool, Callable[[str], List[Token]], Tokenizer] = SpaceTokenizer(),
tokenizer: Union[bool, Tokenizer] = SpaceTokenizer(),
memory_mode: str = "full",
**corpusargs,
):
@@ -1160,7 +1157,7 @@ class SENTEVAL_MR(ClassificationCorpus):

def __init__(
self,
tokenizer: Union[bool, Callable[[str], List[Token]], Tokenizer] = SpaceTokenizer(),
tokenizer: Union[bool, Tokenizer] = SpaceTokenizer(),
memory_mode: str = "full",
**corpusargs,
):
@@ -1214,7 +1211,7 @@ class SENTEVAL_SUBJ(ClassificationCorpus):

def __init__(
self,
tokenizer: Union[bool, Callable[[str], List[Token]], Tokenizer] = SpaceTokenizer(),
tokenizer: Union[bool, Tokenizer] = SpaceTokenizer(),
memory_mode: str = "full",
**corpusargs,
):
@@ -1268,7 +1265,7 @@ class SENTEVAL_MPQA(ClassificationCorpus):

def __init__(
self,
tokenizer: Union[bool, Callable[[str], List[Token]], Tokenizer] = SpaceTokenizer(),
tokenizer: Union[bool, Tokenizer] = SpaceTokenizer(),
memory_mode: str = "full",
**corpusargs,
):
@@ -1322,7 +1319,7 @@ class SENTEVAL_SST_BINARY(ClassificationCorpus):

def __init__(
self,
tokenizer: Union[bool, Callable[[str], List[Token]], Tokenizer] = SpaceTokenizer(),
tokenizer: Union[bool, Tokenizer] = SpaceTokenizer(),
memory_mode: str = "full",
**corpusargs,
):
@@ -1382,7 +1379,7 @@ class SENTEVAL_SST_GRANULAR(ClassificationCorpus):

def __init__(
self,
tokenizer: Union[bool, Callable[[str], List[Token]], Tokenizer] = SpaceTokenizer(),
tokenizer: Union[bool, Tokenizer] = SpaceTokenizer(),
memory_mode: str = "full",
**corpusargs,
):
@@ -1535,7 +1532,7 @@ class GO_EMOTIONS(ClassificationCorpus):
def __init__(
self,
base_path: Union[str, Path] = None,
tokenizer: Union[bool, Callable[[str], List[Token]], Tokenizer] = SegtokTokenizer(),
tokenizer: Union[bool, Tokenizer] = SegtokTokenizer(),
memory_mode: str = "partial",
**corpusargs,
):
@@ -1642,7 +1639,7 @@ class TREC_50(ClassificationCorpus):
def __init__(
self,
base_path: Union[str, Path] = None,
tokenizer: Union[bool, Callable[[str], List[Token]], Tokenizer] = SpaceTokenizer(),
tokenizer: Union[bool, Tokenizer] = SpaceTokenizer(),
memory_mode="full",
**corpusargs,
):
@@ -1704,7 +1701,7 @@ class TREC_6(ClassificationCorpus):
def __init__(
self,
base_path: Union[str, Path] = None,
tokenizer: Union[bool, Callable[[str], List[Token]], Tokenizer] = SpaceTokenizer(),
tokenizer: Union[bool, Tokenizer] = SpaceTokenizer(),
memory_mode="full",
**corpusargs,
):
@@ -1767,7 +1764,7 @@ class YAHOO_ANSWERS(ClassificationCorpus):
def __init__(
self,
base_path: Union[str, Path] = None,
tokenizer: Union[bool, Callable[[str], List[Token]], Tokenizer] = SpaceTokenizer(),
tokenizer: Union[bool, Tokenizer] = SpaceTokenizer(),
memory_mode="partial",
**corpusargs,
):
@@ -1846,7 +1843,7 @@ class GERMEVAL_2018_OFFENSIVE_LANGUAGE(ClassificationCorpus):
def __init__(
self,
base_path: Union[str, Path] = None,
tokenizer: Union[bool, Callable[[str], List[Token]], Tokenizer] = SegtokTokenizer(),
tokenizer: Union[bool, Tokenizer] = SegtokTokenizer(),
memory_mode: str = "full",
fine_grained_classes: bool = False,
**corpusargs,
@@ -1919,7 +1916,7 @@ def __init__(
self,
base_path: Union[str, Path] = None,
memory_mode: str = "full",
tokenizer: Union[bool, Callable[[str], List[Token]], Tokenizer] = SpaceTokenizer(),
tokenizer: Union[bool, Tokenizer] = SpaceTokenizer(),
**corpusargs,
):
"""
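The corpus constructors in this file keep their keyword arguments; only the accepted tokenizer type is narrowed in the same way. A hedged usage sketch (corpus data is downloaded on first use):

```python
# Sketch: corpus constructors take a Tokenizer instance (or a bool) for the
# tokenizer argument; the remaining keyword arguments are unchanged.
from flair.datasets import TREC_6
from flair.tokenization import SpaceTokenizer

corpus = TREC_6(tokenizer=SpaceTokenizer(), memory_mode="full")
print(corpus)  # prints train/dev/test sentence counts
```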
29 changes: 8 additions & 21 deletions flair/datasets/sequence_labeling.py
@@ -8,16 +8,7 @@
from torch.utils.data import ConcatDataset, Dataset

import flair
from flair.data import (
Corpus,
FlairDataset,
MultiCorpus,
RelationLabel,
Sentence,
Span,
SpanLabel,
Token,
)
from flair.data import Corpus, FlairDataset, MultiCorpus, Relation, Sentence, Token
from flair.datasets.base import find_train_dev_test_files
from flair.file_utils import cached_path, unpack_file
from flair.models.sequence_tagger_utils.bioes import get_spans_from_bio
@@ -358,7 +349,7 @@ def _identify_span_columns(self, column_name_map, skip_first_line):
continue

for token in sentence:
if token.get_tag(layer, "O").value != "O" and token.get_tag(layer).value[0:2] not in [
if token.get_label(layer, "O").value != "O" and token.get_label(layer).value[0:2] not in [
"B-",
"I-",
"E-",
@@ -393,7 +384,7 @@ def _convert_lines_to_sentence(
self, lines, word_level_tag_columns: Dict[int, str], span_level_tag_columns: Optional[Dict[int, str]] = None
):

sentence: Sentence = Sentence()
sentence: Sentence = Sentence(text=[])
token: Optional[Token] = None
filtered_lines = []
comments = []
@@ -424,13 +415,9 @@
for span_indices, score, label in predicted_spans:
span = sentence[span_indices[0] : span_indices[-1] + 1]
value = self._remap_label(label)
sentence.add_complex_label(
typename=span_level_tag_columns[span_column],
label=SpanLabel(span=span, value=value, score=score),
)
span.add_label(span_level_tag_columns[span_column], value=value, score=score)
except Exception:
pass
# log.warning(f"--\nUnparseable sentence: {''.join(lines)}--\n")

for comment in comments:
# parse relations if they are set
@@ -444,10 +431,10 @@ def _convert_lines_to_sentence(
tail_end = int(indices[3])
label = indices[4]
# head and tail span indices are 1-indexed and end index is inclusive
head = Span(sentence.tokens[head_start - 1 : head_end])
tail = Span(sentence.tokens[tail_start - 1 : tail_end])

sentence.add_complex_label("relation", RelationLabel(value=label, head=head, tail=tail))
relation = Relation(
first=sentence[head_start - 1 : head_end], second=sentence[tail_start - 1 : tail_end]
)
relation.add_label(typename="relation", value=label)

if len(sentence) > 0:
return sentence
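In the relation parsing above, Relation data points are now built directly from two sentence slices and labeled via add_label(), replacing add_complex_label(RelationLabel(...)). A short sketch of the index handling with an illustrative sentence; the comment-format indices are 1-indexed with inclusive ends, as noted in the code:

```python
# Sketch of the new relation construction: 1-indexed, end-inclusive span
# indices are converted to Python slices before building the Relation.
from flair.data import Relation, Sentence

sentence = Sentence("Albert Einstein was born in Ulm .")
head_start, head_end = 1, 2   # "Albert Einstein"
tail_start, tail_end = 6, 6   # "Ulm"

relation = Relation(
    first=sentence[head_start - 1 : head_end],    # tokens 0..1
    second=sentence[tail_start - 1 : tail_end],   # token 5
)
relation.add_label(typename="relation", value="born_in")  # label value is illustrative
```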
6 changes: 3 additions & 3 deletions flair/datasets/text_text.py
@@ -4,7 +4,7 @@
from typing import List, Optional, Union

import flair
from flair.data import Corpus, DataPair, FlairDataset, Sentence, _iter_dataset
from flair.data import Corpus, DataPair, FlairDataset, Sentence, TextPair, _iter_dataset
from flair.datasets.base import find_train_dev_test_files
from flair.file_utils import cached_path, unpack_file, unzip_file

@@ -180,7 +180,7 @@ def _make_bi_sentence(self, source_line: str, target_line: str):
source_sentence.tokens = source_sentence.tokens[: self.max_tokens_per_doc]
target_sentence.tokens = target_sentence.tokens[: self.max_tokens_per_doc]

return DataPair(source_sentence, target_sentence)
return TextPair(source_sentence, target_sentence)

def __len__(self):
return self.total_sentence_count
@@ -416,7 +416,7 @@ def _make_data_pair(self, first_element: str, second_element: str, label: str =
first_sentence.tokens = first_sentence.tokens[: self.max_tokens_per_doc]
second_sentence.tokens = second_sentence.tokens[: self.max_tokens_per_doc]

data_pair = DataPair(first_sentence, second_sentence)
data_pair = TextPair(first_sentence, second_sentence)

if label:
data_pair.add_label(typename=self.label_type, value=label)
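DataPair is replaced by TextPair for text-pair datasets; labels attach through the same add_label() interface. A minimal sketch (the label type and value are illustrative):

```python
# Sketch of the DataPair -> TextPair switch for text pairs.
from flair.data import Sentence, TextPair

pair = TextPair(Sentence("How old are you?"), Sentence("What is your age?"))
pair.add_label(typename="paraphrase", value="1")  # typename/value chosen for illustration
print(pair)
```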
2 changes: 1 addition & 1 deletion flair/datasets/treebanks.py
@@ -129,7 +129,7 @@ def __getitem__(self, index: int = 0) -> Sentence:

def _read_next_sentence(self, file):
line = file.readline()
sentence: Sentence = Sentence()
sentence: Sentence = Sentence([])

# current token ID
token_idx = 0
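Parsers that build sentences token by token now pass an explicit empty token list instead of calling Sentence() with no arguments. A tiny sketch:

```python
# Sketch: an explicitly empty Sentence is created and tokens are added afterwards.
from flair.data import Sentence, Token

sentence = Sentence([])            # was: Sentence()
sentence.add_token(Token("hello"))
sentence.add_token(Token("world"))
print(len(sentence))               # -> 2
```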
25 changes: 8 additions & 17 deletions flair/embeddings/token.py
@@ -275,7 +275,7 @@ def _add_embeddings_internal(self, sentences: List[Sentence]) -> List[Sentence]:
if "field" not in self.__dict__ or self.field is None:
word = token.text
else:
word = token.get_tag(self.field).value
word = token.get_label(self.field).value
word_indices.append(self.get_cached_token_index(word))

embeddings = self.embedding(torch.tensor(word_indices, dtype=torch.long, device=self.device))
@@ -671,8 +671,7 @@ def __init__(
self.chars_per_chunk: int = chars_per_chunk

# embed a dummy sentence to determine embedding_length
dummy_sentence: Sentence = Sentence()
dummy_sentence.add_token(Token("hello"))
dummy_sentence: Sentence = Sentence("hello")
embedded_dummy = self.embed(dummy_sentence)
self.__embedding_length: int = len(embedded_dummy[0][0].get_embedding())

@@ -975,7 +974,7 @@ def _add_embeddings_internal(self, sentences: List[Sentence]) -> List[Sentence]:
if "field" not in self.__dict__ or self.field is None:
word = token.text
else:
word = token.get_tag(self.field).value
word = token.get_label(self.field).value

word_embedding = self.get_cached_vec(word)

@@ -1039,7 +1038,7 @@ def _add_embeddings_internal(self, sentences: List[Sentence]) -> List[Sentence]:
if self.field == "text":
one_hot_sentences = [self.vocab_dictionary.get_idx_for_item(t.text) for t in tokens]
else:
one_hot_sentences = [self.vocab_dictionary.get_idx_for_item(t.get_tag(self.field).value) for t in tokens]
one_hot_sentences = [self.vocab_dictionary.get_idx_for_item(t.get_label(self.field).value) for t in tokens]

one_hot_sentences_tensor = torch.tensor(one_hot_sentences, dtype=torch.long).to(flair.device)

@@ -1065,7 +1064,7 @@ def from_corpus(cls, corpus: Corpus, field: str = "text", min_freq: int = 3, **k
if field == "text":
most_common = Counter(list(map((lambda t: t.text), tokens))).most_common()
else:
most_common = Counter(list(map((lambda t: t.get_tag(field).value), tokens))).most_common()
most_common = Counter(list(map((lambda t: t.get_label(field).value), tokens))).most_common()

tokens = []
for token, freq in most_common:
@@ -1206,12 +1205,7 @@ def _add_embeddings_internal(self, sentences: List[Sentence]) -> List[Sentence]:

for token, token_idx in zip(sentence.tokens, range(len(sentence.tokens))):

if "field" not in self.__dict__ or self.field is None:
word = token.text
else:
word = token.get_tag(self.field).value

word_embedding = self.get_cached_vec(language_code=language_code, word=word)
word_embedding = self.get_cached_vec(language_code=language_code, word=token.text)

token.set_embedding(self.name, word_embedding)

@@ -1310,10 +1304,7 @@ def _add_embeddings_internal(self, sentences: List[Sentence]) -> List[Sentence]:

for token, token_idx in zip(sentence.tokens, range(len(sentence.tokens))):

if "field" not in self.__dict__ or self.field is None:
word = token.text
else:
word = token.get_tag(self.field).value
word = token.text

if word.strip() == "":
# empty words get no embedding
@@ -1410,7 +1401,7 @@ def __init__(
)

# embed a dummy sentence to determine embedding_length
dummy_sentence: Sentence = Sentence()
dummy_sentence: Sentence = Sentence([])
dummy_sentence.add_token(Token("hello"))
embedded_dummy = self.embed(dummy_sentence)
self.__embedding_length: int = len(embedded_dummy[0][0].get_embedding())
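Embeddings that read a token-level field (for example OneHotEmbeddings with field != "text") now look the value up with get_label() instead of get_tag(). A hedged sketch; it assumes add_label() is the write-side counterpart, which this diff does not show:

```python
# Sketch of the get_tag() -> get_label() switch for field-based lookups.
from flair.data import Sentence

sentence = Sentence("Berlin is nice")
token = sentence.tokens[0]
token.add_label("pos", "NNP")          # assumption: tokens are labeled via add_label()
print(token.get_label("pos").value)    # -> "NNP", replaces token.get_tag("pos").value
```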