Skip to content

Commit

Permalink
Merge pull request #2565 from flairNLP/tars-fixes
Browse files Browse the repository at this point in the history
Tars fixes
  • Loading branch information
alanakbik authored Dec 20, 2021
2 parents 90a781a + 4113b6b commit 680485d
Show file tree
Hide file tree
Showing 3 changed files with 6 additions and 5 deletions.
4 changes: 2 additions & 2 deletions flair/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -1532,11 +1532,11 @@ def iob_iobes(tags):
if tag.value == "O" or tag.value == "":
tag.value = "O"
continue
t, label = tag.value.split("-")
t, label = tag.value.split("-", 1)
if len(tags) == i + 1 or tags[i + 1].value == "O":
next_same = False
else:
nt, next_label = tags[i + 1].value.split("-")
nt, next_label = tags[i + 1].value.split("-", 1)
next_same = nt == "I" and next_label == label
if t == "B":
if not next_same:
Expand Down
5 changes: 4 additions & 1 deletion flair/models/sequence_tagger_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -230,6 +230,10 @@ def RNN(

def forward_loss(self, sentences: Union[List[Sentence], Sentence]) -> Tuple[torch.Tensor, int]:

# if there are no sentences, there is no loss
if len(sentences) == 0:
return torch.tensor(0., dtype=torch.float, device=flair.device, requires_grad=True), 0

# forward pass to get scores
scores, gold_labels = self.forward(sentences) # type: ignore

Expand All @@ -241,7 +245,6 @@ def forward(self, sentences: Union[List[Sentence], Sentence]):
Forward propagation through network. Returns gold labels of batch in addition.
:param sentences: Batch of current sentences
"""

self.embeddings.embed(sentences)

# make a zero-padded tensor for the whole sentence
Expand Down
2 changes: 0 additions & 2 deletions flair/models/tars_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,6 @@ def _get_tars_formatted_sentence(self, label, sentence):
def _get_tars_formatted_sentences(self, sentences: List[Sentence]):
label_text_pairs = []
all_labels = [label.decode("utf-8") for label in self.get_current_label_dictionary().idx2item]
# print(all_labels)
for sentence in sentences:
label_text_pairs_for_sentence = []
if self.training and self.num_negative_labels_to_sample is not None:
Expand Down Expand Up @@ -88,7 +87,6 @@ def _get_nearest_labels_for(self, labels):
import random

sample = random.sample(tags, k=self.num_negative_labels_to_sample)
# print(sample)
return sample

already_sampled_negative_labels = set()
Expand Down

0 comments on commit 680485d

Please sign in to comment.