
Commit

Merge branch 'master' into flairNLPGH-2117-more-flexibility-on-main-metric

# Conflicts:
#	flair/trainers/trainer.py
MLDLMFZ committed Mar 16, 2021
2 parents cace03f + 339fe07 commit f4bee8c
Showing 14 changed files with 814 additions and 131 deletions.
2 changes: 1 addition & 1 deletion flair/__init__.py
@@ -25,7 +25,7 @@

 import logging.config
 
-__version__ = "0.8"
+__version__ = "0.8.1"
 
 logging.config.dictConfig(
     {
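The only substantive change here is the version bump, which is visible at runtime. A quick check (sketch):

import flair

# after installing this release, the package constant reflects the bump
print(flair.__version__)  # -> "0.8.1"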
3 changes: 3 additions & 0 deletions flair/datasets/__init__.py
@@ -14,12 +14,14 @@
 from .sequence_labeling import CONLL_03
 from .sequence_labeling import CONLL_03_GERMAN
 from .sequence_labeling import CONLL_03_DUTCH
+from .sequence_labeling import ICELANDIC_NER
 from .sequence_labeling import CONLL_03_SPANISH
 from .sequence_labeling import CONLL_2000
 from .sequence_labeling import DANE
 from .sequence_labeling import EUROPARL_NER_GERMAN
 from .sequence_labeling import GERMEVAL_14
 from .sequence_labeling import INSPEC
+from .sequence_labeling import JAPANESE_NER
 from .sequence_labeling import LER_GERMAN
 from .sequence_labeling import MIT_MOVIE_NER_SIMPLE
 from .sequence_labeling import MIT_MOVIE_NER_COMPLEX
@@ -56,6 +58,7 @@
 from .sequence_labeling import WSD_UFSAC
 from .sequence_labeling import WNUT_2020_NER
 from .sequence_labeling import XTREME
+from .sequence_labeling import REDDIT_EL_GOLD
 
 # Expose all document classification datasets
 from .document_classification import ClassificationCorpus
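Once exposed here, the new corpora can be imported straight from flair.datasets. A minimal sketch, assuming these loaders follow flair's usual zero-argument, download-on-first-use pattern (constructor details are not part of this diff):

from flair.datasets import ICELANDIC_NER

# instantiating a flair corpus typically downloads and caches it on first use
corpus = ICELANDIC_NER()

print(corpus)           # train/dev/test sentence counts
print(corpus.train[0])  # first training sentence with its NER annotations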
86 changes: 52 additions & 34 deletions flair/datasets/document_classification.py
@@ -11,7 +11,7 @@
     Corpus,
     Token,
     FlairDataset,
-    Tokenizer
+    Tokenizer, DataPair
 )
 from flair.tokenization import SegtokTokenizer, SpaceTokenizer
 from flair.datasets.base import find_train_dev_test_files
@@ -454,9 +454,12 @@ def __init__(

         # most data sets have the token text in the first column, if not, pass 'text' as column
         self.text_columns: List[int] = []
+        self.pair_columns: List[int] = []
         for column in column_name_map:
             if column_name_map[column] == "text":
                 self.text_columns.append(column)
+            if column_name_map[column] == "pair":
+                self.pair_columns.append(column)
 
         with open(self.path_to_file, encoding=encoding) as csv_file:
 
@@ -488,33 +491,61 @@

                 if self.in_memory:
 
-                    text = " ".join(
-                        [row[text_column] for text_column in self.text_columns]
-                    )
-
-                    if self.max_chars_per_doc > 0:
-                        text = text[: self.max_chars_per_doc]
-
-                    sentence = Sentence(text, use_tokenizer=self.tokenizer)
-
-                    for column in self.column_name_map:
-                        column_value = row[column]
-                        if (
-                            self.column_name_map[column].startswith("label")
-                            and column_value
-                        ):
-                            if column_value != self.no_class_label:
-                                sentence.add_label(label_type, column_value)
+                    sentence = self._make_labeled_data_point(row)
 
                     if 0 < self.max_tokens_per_doc < len(sentence):
                         sentence.tokens = sentence.tokens[: self.max_tokens_per_doc]
                     self.sentences.append(sentence)
 
                 else:
                     self.raw_data.append(row)
 
                 self.total_sentence_count += 1

+    def _make_labeled_data_point(self, row):
+
+        # make sentence from text (and filter for length)
+        text = " ".join(
+            [row[text_column] for text_column in self.text_columns]
+        )
+
+        if self.max_chars_per_doc > 0:
+            text = text[: self.max_chars_per_doc]
+
+        sentence = Sentence(text, use_tokenizer=self.tokenizer)
+
+        if 0 < self.max_tokens_per_doc < len(sentence):
+            sentence.tokens = sentence.tokens[: self.max_tokens_per_doc]
+
+        # if a pair column is defined, make a sentence pair object
+        if len(self.pair_columns) > 0:
+
+            text = " ".join(
+                [row[pair_column] for pair_column in self.pair_columns]
+            )
+
+            if self.max_chars_per_doc > 0:
+                text = text[: self.max_chars_per_doc]
+
+            pair = Sentence(text, use_tokenizer=self.tokenizer)
+
+            if 0 < self.max_tokens_per_doc < len(sentence):
+                pair.tokens = pair.tokens[: self.max_tokens_per_doc]
+
+            data_point = DataPair(first=sentence, second=pair)
+
+        else:
+            data_point = sentence
+
+        for column in self.column_name_map:
+            column_value = row[column]
+            if (
+                self.column_name_map[column].startswith("label")
+                and column_value
+            ):
+                if column_value != self.no_class_label:
+                    data_point.add_label(self.label_type, column_value)
+
+        return data_point
 
     def is_in_memory(self) -> bool:
         return self.in_memory

@@ -527,20 +558,7 @@ def __getitem__(self, index: int = 0) -> Sentence:
         else:
             row = self.raw_data[index]
 
-            text = " ".join([row[text_column] for text_column in self.text_columns])
-
-            if self.max_chars_per_doc > 0:
-                text = text[: self.max_chars_per_doc]
-
-            sentence = Sentence(text, use_tokenizer=self.tokenizer)
-            for column in self.column_name_map:
-                column_value = row[column]
-                if self.column_name_map[column].startswith("label") and column_value:
-                    if column_value != self.no_class_label:
-                        sentence.add_label(self.label_type, column_value)
-
-            if 0 < self.max_tokens_per_doc < len(sentence):
-                sentence.tokens = sentence.tokens[: self.max_tokens_per_doc]
+            sentence = self._make_labeled_data_point(row)
 
         return sentence

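Both the in-memory path and the lazy __getitem__ path now go through the new _make_labeled_data_point helper, so mapping a CSV column to "pair" yields DataPair objects suitable for sentence-pair classification. A hypothetical usage sketch (folder name, file layout, and column indices are assumptions, not part of this diff):

from flair.datasets import CSVClassificationCorpus

# hypothetical CSV layout: column 0 = first sentence, column 1 = second
# sentence, column 2 = class label (e.g. a textual-entailment-style task)
column_name_map = {0: "text", 1: "pair", 2: "label"}

corpus = CSVClassificationCorpus(
    "resources/tasks/my_pair_task",  # assumed folder holding train/dev/test CSVs
    column_name_map,
    skip_header=True,                # assumed: files start with a header row
)

# data points are now DataPair(first=..., second=...) rather than plain Sentences
print(corpus.train[0])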
[Diffs for the remaining 11 changed files not loaded.]
