Skip to content

Commit

Permalink
Optimize FastText.load_fasttext_model (#2340)
Browse files Browse the repository at this point in the history
* add docstring for Model namedtuple

* add option to skip hidden matrix loading

* review response: rename fast -> full_model

* speed up hash function based on ideas from @horpto and @menshikh-iv

* remove obsolete ft_hash function

* review response: update docstring

* attempt to hack around appveyor Py2.7 build missing stdint.h

* fixup: add missing int8_t typedef

* review response: avoid split and join

* review response: add comment to explain hack

* review response: improve logging message

* review response: fix hash_main function

* fixup: fix test_utils.py

* add tests for ngram generation

* fixup in tests

* add emoji test case

* minor fixup in logging message

* add byte tests

* remove FIXME, absense of ord does not influence correctness

* review response: introduce list slicing

* avoid using fstrings for Py2 compatibility

* flake8

* more Py2 compatibility

* flake8

* review response: get rid of set()

* review response: remove excess bytes() call

* fix tests (wide unicode issue)

* add test against actual FB implementation

* adding temporary benchmarking code

* replacing non-optimized code with optimized code

* removing temporary benchmarking code

* remove wide characters from fb test code
  • Loading branch information
mpenkov authored and menshikh-iv committed Jan 24, 2019
1 parent 0cc0994 commit 411f546
Show file tree
Hide file tree
Showing 10 changed files with 2,544 additions and 2,466 deletions.
58 changes: 53 additions & 5 deletions gensim/models/_fasttext_bin.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,49 @@ def _yield_field_names():

_FIELD_NAMES = sorted(set(_yield_field_names()))
Model = collections.namedtuple('Model', _FIELD_NAMES)
"""Holds data loaded from the Facebook binary.
Fields
------
dim : int
The dimensionality of the vectors.
ws : int
The window size.
epoch : int
The number of training epochs.
neg : int
If non-zero, indicates that the model uses negative sampling.
loss : int
If equal to 1, indicates that the model uses hierarchical sampling.
model : int
If equal to 2, indicates that the model uses skip-grams.
bucket : int
The number of buckets.
min_count : int
The threshold below which the model ignores terms.
t : float
The sample threshold.
minn : int
The minimum ngram length.
maxn : int
The maximum ngram length.
raw_vocab : collections.OrderedDict
A map from words (str) to their frequency (int). The order in the dict
corresponds to the order of the words in the Facebook binary.
nwords : int
The number of words.
vocab_size : int
The size of the vocabulary.
vectors_ngrams : numpy.array
This is a matrix that contains vectors learned by the model.
Each row corresponds to a vector.
The number of vectors is equal to the number of words plus the number of buckets.
The number of columns is equal to the vector dimensionality.
hidden_output : numpy.array
This is a matrix that contains the shallow neural network output.
This array has the same dimensions as vectors_ngrams.
May be None - in that case, it is impossible to continue training the model.
"""


def _struct_unpack(fin, fmt):
Expand Down Expand Up @@ -177,7 +220,7 @@ def _load_matrix(fin, new_format=True):
return matrix


def load(fin, encoding='utf-8'):
def load(fin, encoding='utf-8', full_model=True):
"""Load a model from a binary stream.
Parameters
Expand All @@ -186,6 +229,9 @@ def load(fin, encoding='utf-8'):
The readable binary stream.
encoding : str, optional
The encoding to use for decoding text
full_model : boolean, optional
If False, skips loading the hidden output matrix. This saves a fair bit
of CPU time and RAM, but prevents training continuation.
Returns
-------
Expand All @@ -209,10 +255,12 @@ def load(fin, encoding='utf-8'):

vectors_ngrams = _load_matrix(fin, new_format=new_format)

hidden_output = _load_matrix(fin, new_format=new_format)
model.update(vectors_ngrams=vectors_ngrams, hidden_output=hidden_output)

assert fin.read() == b'', 'expected to reach EOF'
if not full_model:
hidden_output = None
else:
hidden_output = _load_matrix(fin, new_format=new_format)
assert fin.read() == b'', 'expected to reach EOF'

model.update(vectors_ngrams=vectors_ngrams, hidden_output=hidden_output)
model = {k: v for k, v in model.items() if k in _FIELD_NAMES}
return Model(**model)
Loading

0 comments on commit 411f546

Please sign in to comment.