Merge pull request #2574 from vsamarths/master
Updated token.py for gradient calculations
alanakbik authored Dec 29, 2021
2 parents 680485d + 7b1dff8 · commit bba5b5c
Showing 1 changed file with 6 additions and 6 deletions.

flair/embeddings/token.py
@@ -1145,18 +1145,18 @@ def _add_embeddings_internal(self, sentences: List[Sentence]):
                     lang_id = self.tokenizer.lang2id.get(sentences[s_id].get_language_code(), 0)  # type: ignore
                     model_kwargs["langs"][s_id][:sequence_length] = lang_id
 
-            # put encoded batch through transformer model to get all hidden states of all encoder layers
-            hidden_states = self.model(input_ids, **model_kwargs)[-1]
-            # make the tuple a tensor; makes working with it easier.
-            hidden_states = torch.stack(hidden_states)
-
-            sentence_idx_offset = 0
+            # gradients are enabled if fine-tuning is enabled
+            gradient_context = torch.enable_grad() if (self.fine_tune and self.training) else torch.no_grad()
+
+            with gradient_context:
+                # put encoded batch through transformer model to get all hidden states of all encoder layers
+                hidden_states = self.model(input_ids, **model_kwargs)[-1]
+                # make the tuple a tensor; makes working with it easier.
+                hidden_states = torch.stack(hidden_states)
+
+                sentence_idx_offset = 0
             # iterate over all subtokenized sentences
             for sentence_idx, (
                 sentence,
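
For readers skimming the diff: the new gradient_context gates the transformer forward pass so that autograd is active only when the embedding is configured to fine-tune and the module is in training mode; in every other case the forward pass runs under torch.no_grad(), which skips gradient bookkeeping and saves memory during pure inference. Below is a minimal, self-contained sketch of the same pattern; ToyEmbedder and its fine_tune flag are hypothetical stand-ins for illustration, not flair's actual classes.

    import torch
    import torch.nn as nn

    class ToyEmbedder(nn.Module):
        """Hypothetical module illustrating the conditional-gradient pattern."""

        def __init__(self, fine_tune: bool = False):
            super().__init__()
            self.fine_tune = fine_tune
            self.model = nn.Linear(8, 8)  # stand-in for the transformer

        def embed(self, input_ids: torch.Tensor) -> torch.Tensor:
            # gradients are tracked only if fine-tuning is enabled and we are training
            gradient_context = torch.enable_grad() if (self.fine_tune and self.training) else torch.no_grad()
            with gradient_context:
                hidden_states = self.model(input_ids)
            return hidden_states

    embedder = ToyEmbedder(fine_tune=True)

    embedder.train()
    print(embedder.embed(torch.randn(2, 8)).requires_grad)  # True: autograd tracks the pass

    embedder.eval()
    print(embedder.embed(torch.randn(2, 8)).requires_grad)  # False: no_grad skips tracking

Assigning the context manager to a variable and entering it with a single with statement keeps one code path for both modes, rather than duplicating the forward pass inside an if/else.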
