From 7b1dff854d861cc24ad4fe57f9895fa392ab74aa Mon Sep 17 00:00:00 2001
From: vsamarths
Date: Tue, 28 Dec 2021 11:38:39 +0530
Subject: [PATCH] Update token.py

moved fwd pass to no_grad
---
 flair/embeddings/token.py | 12 ++++++------
 1 file changed, 6 insertions(+), 6 deletions(-)

diff --git a/flair/embeddings/token.py b/flair/embeddings/token.py
index 3111210ea..0d81be0c1 100644
--- a/flair/embeddings/token.py
+++ b/flair/embeddings/token.py
@@ -1145,18 +1145,18 @@ def _add_embeddings_internal(self, sentences: List[Sentence]):
                 lang_id = self.tokenizer.lang2id.get(sentences[s_id].get_language_code(), 0)  # type: ignore
                 model_kwargs["langs"][s_id][:sequence_length] = lang_id
 
-        # put encoded batch through transformer model to get all hidden states of all encoder layers
-        hidden_states = self.model(input_ids, **model_kwargs)[-1]
-        # make the tuple a tensor; makes working with it easier.
-        hidden_states = torch.stack(hidden_states)
-
-        sentence_idx_offset = 0
+        # gradients are enabled if fine-tuning is enabled
         gradient_context = torch.enable_grad() if (self.fine_tune and self.training) else torch.no_grad()
 
         with gradient_context:
+            # put encoded batch through transformer model to get all hidden states of all encoder layers
+            hidden_states = self.model(input_ids, **model_kwargs)[-1]
+            # make the tuple a tensor; makes working with it easier.
+            hidden_states = torch.stack(hidden_states)
+            sentence_idx_offset = 0
 
             # iterate over all subtokenized sentences
             for sentence_idx, (
                 sentence,
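
A minimal standalone sketch of the gradient-context pattern this patch applies:
the forward pass runs under torch.no_grad() unless the module is fine-tuning in
training mode, so feature extraction builds no autograd graph. The Encoder class
and its fine_tune flag below are hypothetical stand-ins for flair's embedding
class, not flair's API.

    import torch
    import torch.nn as nn

    class Encoder(nn.Module):
        """Hypothetical stand-in for flair's embedding module; only the
        gradient-context logic mirrors the patch."""

        def __init__(self, fine_tune: bool = False):
            super().__init__()
            self.fine_tune = fine_tune  # assumed flag, mirroring self.fine_tune in the diff
            self.layer = nn.Linear(16, 16)

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            # Track gradients only when fine-tuning during training; otherwise
            # run under no_grad so no autograd graph (and its memory) is kept.
            gradient_context = torch.enable_grad() if (self.fine_tune and self.training) else torch.no_grad()
            with gradient_context:
                return self.layer(x)

    frozen = Encoder(fine_tune=False).train()
    print(frozen(torch.randn(2, 16)).requires_grad)  # False: forward ran inside no_grad

    tuned = Encoder(fine_tune=True).train()
    print(tuned(torch.randn(2, 16)).requires_grad)  # True: gradients flow for fine-tuning

Note the design point the diff illustrates: before the change, the model forward
pass ran before entering the gradient context, so the context never applied to
it and activations were always stored for backprop. Moving the forward pass
inside the `with gradient_context:` block is what actually avoids that cost in
feature-extraction mode.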