Skip to content

Commit

Permalink
Merge branch 'master' into flairNLPgh-3474/add-random-seed-to-datasets
Browse files Browse the repository at this point in the history
  • Loading branch information
alanakbik authored Jul 1, 2024
2 parents 67cfe08 + 59bd705 commit 6ca537c
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 3 deletions.
6 changes: 3 additions & 3 deletions flair/models/sequence_tagger_utils/viterbi.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,16 +221,16 @@ def decode(
)

if probabilities_for_all_classes:
all_tags = self._all_scores_for_token(scores.cpu(), tag_seq, lengths, sentences)
all_tags = self._all_scores_for_token(scores.cpu(), decoded.cpu(), lengths, sentences)

return tags, all_tags

def _all_scores_for_token(
self, scores: torch.Tensor, tag_seq: torch.IntTensor, lengths: torch.IntTensor, sentences: List[Sentence]
self, scores: torch.Tensor, tag_sequences: torch.Tensor, lengths: torch.IntTensor, sentences: List[Sentence]
):
"""Returns all scores for each tag in tag dictionary."""
scores = scores.numpy()
for i_batch, batch in enumerate(scores):
for i_batch, (batch, tag_seq) in enumerate(zip(scores, tag_sequences)):
for i, (tag_id, tag_scores) in enumerate(zip(tag_seq, batch)):
tag_id_int = tag_id if isinstance(tag_id, int) else int(tag_id.item())

Expand Down
3 changes: 3 additions & 0 deletions flair/trainers/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -801,6 +801,9 @@ def train_custom(
self.return_values["test_score"] = test_results.main_score

else:
if (base_path / "best-model.pt").exists():
log.info("Loading model from best epoch ...")
self.model.load_state_dict(self.model.load(base_path / "best-model.pt").state_dict())
self.return_values["test_score"] = 0
log.info("Test data not provided setting final score to 0")

Expand Down

0 comments on commit 6ca537c

Please sign in to comment.