Serialization of embeddings #3011
Status of Embeddings:
…t the naming better and add loading of jit models
Hello @helpmefindaname this is really cool, thanks for creating this! Some initial thoughts for discussion:
model = FlairClassifier.load("ner")
model = FlairClassifier.load("ner", "pos", "relations")
it would load a whole pipeline that when calling …
from flair.data import Dictionary
from flair.nn.recurrent import create_recurrent_layer

@AutoFlairModel.register
Why is the LanguageModel registered as AutoFlairModel?
Thanks again for improving this @helpmefindaname! Regarding our discussion on whether/how to merge …
To be safer with regard to pickle, I propose using a dict format to store all properties required to recreate the embeddings (the weights are stored with the model itself anyway).
This allows opening Flair models with incompatible parameters via `torch.load(...)` and therefore allows debugging version conflicts.

During development I also found & fixed the following issues:
- `DocumentLMEmbeddings` were not providing the right names for their embeddings, so the correct usage `doc_lm_embedding.embed(sentence); sentence.get_embedding(doc_lm_embedding.get_names())` would result in an empty tensor.
- `train()` method didn't call its super method; the `.eval()` call in the `__init__` was negated, leading to dropout staying enabled, as that is the default.
- `.eval()` mode after creating.
- `from flair.models import TextRegressor`

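To illustrate the `train()` / `.eval()` issue from the list above, here is a minimal, dependency-free sketch. `Module` and `Dropout` are hypothetical stand-ins that only mimic the mode bookkeeping of `torch.nn.Module`, and `BuggyEmbedding` / `FixedEmbedding` are illustrative names, not Flair classes. The actual fix in this PR may look different.

```python
class Module:
    """Hypothetical stand-in for the train/eval bookkeeping of torch.nn.Module."""

    def __init__(self):
        self.training = True  # like torch, modules start in training mode
        self._children = []

    def add_child(self, child):
        self._children.append(child)
        return child

    def train(self, mode=True):
        # Switch this module and all registered children to the given mode.
        self.training = mode
        for child in self._children:
            child.train(mode)
        return self

    def eval(self):
        return self.train(False)


class Dropout(Module):
    """Dropout is considered active whenever `training` is True."""


class BuggyEmbedding(Module):
    def __init__(self, fine_tune=False):
        super().__init__()
        self.fine_tune = fine_tune
        self.dropout = self.add_child(Dropout())
        if fine_tune:  # bug 1: condition is negated, frozen embeddings never call eval()
            self.eval()

    def train(self, mode=True):
        # bug 2: super().train() is never called, so children keep their old mode
        self.training = mode and self.fine_tune
        return self


class FixedEmbedding(Module):
    def __init__(self, fine_tune=False):
        super().__init__()
        self.fine_tune = fine_tune
        self.dropout = self.add_child(Dropout())
        if not fine_tune:  # frozen embeddings start in eval mode
            self.eval()

    def train(self, mode=True):
        # Delegate to super so that dropout follows the effective mode.
        return super().train(mode and self.fine_tune)


buggy = BuggyEmbedding()
buggy.eval()
fixed = FixedEmbedding()
print(buggy.dropout.training, fixed.dropout.training)
```

In the buggy variant the dropout child stays in training mode even after an explicit `eval()`, because the overridden `train()` never propagates the mode. The fixed variant keeps frozen embeddings in eval mode regardless of later `train()` calls.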
This also implements two classes, `AutoFlairModel` and `AutoFlairClassifier`, which can be used to load any model, given that its type is clear.

Example usages are here:
The difference between `AutoFlairModel` and `AutoFlairClassifier` is that `AutoFlairClassifier` is limited to classifiers only (no text regressor), while it provides stronger type hints (all the extra methods the Classifier provides, e.g. `predict`).

Potential issues are:
`model = SequenceTagger.load("my-model.pt")` to `model = AutoFlairClassifier.load("my-model.pt")`
I would recommend loading it once and saving it again on the newest version.
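
The dict-based storage and the `register` decorator discussed in this PR can be sketched in plain Python as follows. This is only an illustration under assumptions: `register`, `to_params`, `from_params`, `save_embedding`, `load_embedding`, the `"__cls__"` key, and `ToyWordEmbedding` are hypothetical names chosen for the example and are not taken from the Flair code base.

```python
import json
from typing import Any, Dict, Type

# Hypothetical registry: class name -> class that can rebuild itself from a dict.
_REGISTRY: Dict[str, Type["Embedding"]] = {}


def register(cls):
    """Class decorator that makes a class loadable by name."""
    _REGISTRY[cls.__name__] = cls
    return cls


class Embedding:
    def to_params(self) -> Dict[str, Any]:
        raise NotImplementedError

    @classmethod
    def from_params(cls, params: Dict[str, Any]) -> "Embedding":
        return cls(**params)


@register
class ToyWordEmbedding(Embedding):
    def __init__(self, name: str, dim: int):
        self.name = name
        self.dim = dim

    def to_params(self) -> Dict[str, Any]:
        # Only plain, inspectable values; weights would be stored separately.
        return {"name": self.name, "dim": self.dim}


def save_embedding(embedding: Embedding) -> Dict[str, Any]:
    """Describe an embedding as a plain dict instead of pickling the object."""
    return {"__cls__": type(embedding).__name__, **embedding.to_params()}


def load_embedding(state: Dict[str, Any]) -> Embedding:
    """Rebuild an embedding from its dict description via the registry."""
    params = dict(state)  # copy so the caller's dict is not modified
    cls_name = params.pop("__cls__")
    if cls_name not in _REGISTRY:
        raise KeyError(f"unknown embedding type: {cls_name}")
    return _REGISTRY[cls_name].from_params(params)


state = save_embedding(ToyWordEmbedding(name="toy", dim=4))
restored = load_embedding(json.loads(json.dumps(state)))
print(type(restored).__name__, restored.name, restored.dim)
```

Because the stored state is a plain dict, it can be opened and inspected even when the constructor parameters no longer match the installed version, which is the debugging benefit described above.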