Skip to content

Commit

Permalink
Merge pull request #2333 from flairNLP/relation_classification_script
Browse files Browse the repository at this point in the history
Relation Extraction (RE) and refactoring of interfaces
  • Loading branch information
alanakbik authored Jul 9, 2021
2 parents 042f71d + ff6e1ef commit 8a2946e
Show file tree
Hide file tree
Showing 30 changed files with 1,874 additions and 1,217 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ wheels/
MANIFEST

.idea/
.vscode/

# PyInstaller
# Usually these files are written by a python script from a template
Expand Down
112 changes: 94 additions & 18 deletions flair/data.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import torch, flair
import logging
import re
import ast

from abc import abstractmethod, ABC

Expand Down Expand Up @@ -176,6 +177,49 @@ def __str__(self):
def __repr__(self):
return f"{self._value} ({round(self._score, 4)})"

@property
def identifier(self):
return ""


class SpanLabel(Label):
def __init__(self, span, value: str, score: float = 1.0):
super().__init__(value, score)
self.span = span

def __str__(self):
return f"{self._value} [{self.span.id_text}] ({round(self._score, 4)})"

def __repr__(self):
return f"{self._value} [{self.span.id_text}] ({round(self._score, 4)})"

def __len__(self):
return len(self.span)

@property
def identifier(self):
return f"{self.span.id_text}"


class RelationLabel(Label):
def __init__(self, head, tail, value: str, score: float = 1.0):
super().__init__(value, score)
self.head = head
self.tail = tail

def __str__(self):
return f"{self._value} [{self.head.id_text} -> {self.tail.id_text}] ({round(self._score, 4)})"

def __repr__(self):
return f"{self._value} from {self.head.id_text} -> {self.tail.id_text} ({round(self._score, 4)})"

def __len__(self):
return len(self.head) + len(self.tail)

@property
def identifier(self):
return f"{self.head.id_text} -> {self.tail.id_text}"


class DataPoint:
"""
Expand All @@ -201,29 +245,37 @@ def to(self, device: str, pin_memory: bool = False):
def clear_embeddings(self, embedding_names: List[str] = None):
pass

def add_label(self, label_type: str, value: str, score: float = 1.):
def add_label(self, typename: str, value: str, score: float = 1.):

if label_type not in self.annotation_layers:
self.annotation_layers[label_type] = [Label(value, score)]
if typename not in self.annotation_layers:
self.annotation_layers[typename] = [Label(value, score)]
else:
self.annotation_layers[label_type].append(Label(value, score))
self.annotation_layers[typename].append(Label(value, score))

return self

def set_label(self, label_type: str, value: str, score: float = 1.):
self.annotation_layers[label_type] = [Label(value, score)]
def add_complex_label(self, typename: str, label: Label):

if typename not in self.annotation_layers:
self.annotation_layers[typename] = [label]
else:
self.annotation_layers[typename].append(label)

return self

def set_label(self, typename: str, value: str, score: float = 1.):
self.annotation_layers[typename] = [Label(value, score)]
return self

def remove_labels(self, label_type: str):
if label_type in self.annotation_layers.keys():
del self.annotation_layers[label_type]
def remove_labels(self, typename: str):
if typename in self.annotation_layers.keys():
del self.annotation_layers[typename]

def get_labels(self, label_type: str = None):
if label_type is None:
def get_labels(self, typename: str = None):
if typename is None:
return self.labels

return self.annotation_layers[label_type] if label_type in self.annotation_layers else []
return self.annotation_layers[typename] if typename in self.annotation_layers else []

@property
def labels(self) -> List[Label]:
Expand Down Expand Up @@ -418,7 +470,7 @@ def to_original_text(self) -> str:
pos += len(t.text)

return str

def to_plain_string(self):
plain = ""
for token in self.tokens:
Expand All @@ -438,16 +490,20 @@ def to_dict(self):
def __str__(self) -> str:
ids = ",".join([str(t.idx) for t in self.tokens])
label_string = " ".join([str(label) for label in self.labels])
labels = f' [− Labels: {label_string}]' if self.labels is not None else ""
labels = f' [− Labels: {label_string}]' if self.labels else ""
return (
'Span [{}]: "{}"{}'.format(ids, self.text, labels)
)

@property
def id_text(self) -> str:
return f"{' '.join([t.text for t in self.tokens])} ({','.join([str(t.idx) for t in self.tokens])})"

def __repr__(self) -> str:
ids = ",".join([str(t.idx) for t in self.tokens])
return (
'<{}-span ({}): "{}">'.format(self.tag, ids, self.text)
if self.tag is not None
if len(self.labels) > 0
else '<span ({}): "{}">'.format(ids, self.text)
)

Expand All @@ -468,6 +524,10 @@ def tag(self):
def score(self):
return self.labels[0].score

@property
def position_string(self):
return '-'.join([str(token.idx) for token in self])


class Tokenizer(ABC):
r"""An abstract class representing a :class:`Tokenizer`.
Expand Down Expand Up @@ -594,6 +654,8 @@ def __init__(
# some sentences represent a document boundary (but most do not)
self.is_document_boundary: bool = False

self.relations: List[Relation] = []

def get_token(self, token_id: int) -> Token:
for token in self.tokens:
if token.idx == token_id:
Expand Down Expand Up @@ -669,7 +731,7 @@ def _add_spans_internal(self, spans: List[Span], label_type: str, min_score):
if span_score > min_score:
span = Span(current_span)
span.add_label(
label_type=label_type,
typename=label_type,
value=sorted(tags.items(), key=lambda k_v: k_v[1], reverse=True)[0][0],
score=span_score)
spans.append(span)
Expand All @@ -691,7 +753,7 @@ def _add_spans_internal(self, spans: List[Span], label_type: str, min_score):
if span_score > min_score:
span = Span(current_span)
span.add_label(
label_type=label_type,
typename=label_type,
value=sorted(tags.items(), key=lambda k_v: k_v[1], reverse=True)[0][0],
score=span_score)
spans.append(span)
Expand Down Expand Up @@ -990,6 +1052,19 @@ def is_context_set(self) -> bool:
"""
return '_previous_sentence' in self.__dict__.keys() or '_position_in_dataset' in self.__dict__.keys()

def get_labels(self, label_type: str = None):

# TODO: crude hack - replace with something better
if label_type:
spans = self.get_spans(label_type)
for span in spans:
self.add_complex_label(label_type, label=SpanLabel(span, span.tag, span.score))

if label_type is None:
return self.labels

return self.annotation_layers[label_type] if label_type in self.annotation_layers else []


class Image(DataPoint):

Expand Down Expand Up @@ -1321,6 +1396,7 @@ def make_label_dictionary(self, label_type: str = None) -> Dictionary:
if isinstance(sentence, Sentence):
for token in sentence.tokens:
for label in token.get_labels(label_type):
# print(label)
label_dictionary.add_item(label.value)

if not label_dictionary.multi_label:
Expand Down Expand Up @@ -1442,4 +1518,4 @@ def randomly_split_into_two_datasets(dataset, length_of_first):
first_dataset.sort()
second_dataset.sort()

return [Subset(dataset, first_dataset), Subset(dataset, second_dataset)]
return [Subset(dataset, first_dataset), Subset(dataset, second_dataset)]
5 changes: 4 additions & 1 deletion flair/datasets/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -256,4 +256,7 @@
from .biomedical import BIOBERT_SPECIES_S800
from .biomedical import BIOBERT_GENE_BC2GM
from .biomedical import BIOBERT_GENE_JNLPBA
from.treebanks import UD_LATIN
from .treebanks import UD_LATIN

# Expose all relation extraction datasets
from .relation_extraction import SEMEVAL_2010_TASK_8
Loading

0 comments on commit 8a2946e

Please sign in to comment.