
Updated token.py for gradient calculations #2574

Merged 1 commit into flairNLP:master on Dec 29, 2021

Conversation

vsamarths
Contributor

The TransformerWordEmbeddings forward pass was not placed inside the gradient_context, so it was not wrapped in torch.no_grad() when fine_tune=False is set. This fix improves GPU memory usage and speed.

moved fwd pass to no_grad
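
For context, a minimal sketch of the pattern (not the actual flair code; the class name, the fine_tune flag handling, and the last_hidden_state access are illustrative assumptions): the transformer forward pass runs under torch.no_grad() unless fine-tuning is requested, so no activation graph is kept while the transformer weights are frozen.

```python
import torch


class FrozenOrTunedEmbeddings(torch.nn.Module):
    """Minimal sketch: pick the gradient context from a fine_tune flag.

    The wrapped `model` and its `last_hidden_state` output assume a
    Hugging Face transformer; names here are illustrative, not flair's API.
    """

    def __init__(self, model: torch.nn.Module, fine_tune: bool = False):
        super().__init__()
        self.model = model
        self.fine_tune = fine_tune

    def forward(self, input_ids: torch.Tensor) -> torch.Tensor:
        # When not fine-tuning, run the transformer under torch.no_grad():
        # no activations are stored for backprop, which saves GPU memory
        # and speeds up the forward pass.
        gradient_context = torch.enable_grad() if self.fine_tune else torch.no_grad()
        with gradient_context:
            hidden_states = self.model(input_ids).last_hidden_state
        return hidden_states
```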
@alanakbik
Collaborator

@vsamarths thanks for improving this!

@helpmefindaname can you take a look? (Since you're refactoring the TransformerEmbeddings at the moment ;))

@helpmefindaname
Collaborator

This looks good, I've added the change to my PR.

@alanakbik
Collaborator

Great! I'll also merge this to master to recognize the contribution! Thanks a lot @vsamarths!

@alanakbik alanakbik merged commit bba5b5c into flairNLP:master Dec 29, 2021