Speed up word2vec model loading (#2671)
* Speed up word2vec binary model loading (#2642)

* Add correctness tests for optimized word2vec model loading (#2642)

* Incorporate Radim's review remarks into the vector-loading speedup code (#2671)

* Incorporate Michael's review remarks into the vector-loading speedup code (#2671)

* Refactor _load_word2vec_format into a few functions for better readability

* Clean up the _add_word_to_result function
lopusz authored and piskvorky committed Nov 18, 2019
1 parent f72a55d commit 1052b9b
Showing 2 changed files with 206 additions and 48 deletions.
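
The core change: instead of reading the binary file byte by byte to find each word boundary, the loader now reads fixed-size chunks and parses words and vectors out of them with numpy.frombuffer. A minimal usage sketch of the new binary_chunk_size knob, calling the internal _load_word2vec_format helper directly the way the new tests below do ("vectors.bin" is a hypothetical file; whether the public KeyedVectors.load_word2vec_format wrapper forwards this parameter is not shown in this commit):

import gensim.models
from gensim.models.utils_any2vec import _load_word2vec_format

# Load C-format binary vectors, reading 1 MiB at a time instead of the 100 KiB default.
kv = _load_word2vec_format(
    cls=gensim.models.KeyedVectors,
    fname="vectors.bin",  # hypothetical input file
    binary=True,
    binary_chunk_size=1024 * 1024,
)
print(kv.vector_size, len(kv.vocab))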
132 changes: 84 additions & 48 deletions gensim/models/utils_any2vec.py
@@ -2,7 +2,7 @@
# -*- coding: utf-8 -*-
#
# Author: Shiva Manne <s.manne@rare-technologies.com>
# Copyright (C) 2018 RaRe Technologies s.r.o.
# Copyright (C) 2019 RaRe Technologies s.r.o.

"""General functions used for any2vec models.
@@ -28,7 +28,7 @@
import logging
from gensim import utils

from numpy import zeros, dtype, float32 as REAL, ascontiguousarray, fromstring
from numpy import zeros, dtype, float32 as REAL, ascontiguousarray, frombuffer

from six.moves import range
from six import iteritems, PY2
@@ -146,8 +146,83 @@ def _save_word2vec_format(fname, vocab, vectors, fvocab=None, binary=False, tota
                fout.write(utils.to_utf8("%s %s\n" % (word, ' '.join(repr(val) for val in row))))


# Functions for internal use by _load_word2vec_format function


def _add_word_to_result(result, counts, word, weights, vocab_size):
    from gensim.models.keyedvectors import Vocab
    word_id = len(result.vocab)
    if word in result.vocab:
        logger.warning("duplicate word '%s' in word2vec file, ignoring all but first", word)
        return
    if counts is None:
        # most common scenario: no vocab file given. just make up some bogus counts, in descending order
        word_count = vocab_size - word_id
    elif word in counts:
        # use count from the vocab file
        word_count = counts[word]
    else:
        logger.warning("vocabulary file is incomplete: '%s' is missing", word)
        word_count = None

    result.vocab[word] = Vocab(index=word_id, count=word_count)
    result.vectors[word_id] = weights
    result.index2word.append(word)


def _add_bytes_to_result(result, counts, chunk, vocab_size, vector_size, datatype, unicode_errors):
    start = 0
    processed_words = 0
    bytes_per_vector = vector_size * dtype(REAL).itemsize
    max_words = vocab_size - len(result.vocab)
    for _ in range(max_words):
        i_space = chunk.find(b' ', start)
        i_vector = i_space + 1

        if i_space == -1 or (len(chunk) - i_vector) < bytes_per_vector:
            break

        word = chunk[start:i_space].decode("utf-8", errors=unicode_errors)
        # Some binary files are reported to have an obsolete newline at the start of the word; strip it.
        word = word.lstrip('\n')
        vector = frombuffer(chunk, offset=i_vector, count=vector_size, dtype=REAL).astype(datatype)
        _add_word_to_result(result, counts, word, vector, vocab_size)
        start = i_vector + bytes_per_vector
        processed_words += 1

    return processed_words, chunk[start:]


def _word2vec_read_binary(fin, result, counts, vocab_size, vector_size, datatype, unicode_errors, binary_chunk_size):
    chunk = b''
    tot_processed_words = 0

    while tot_processed_words < vocab_size:
        new_chunk = fin.read(binary_chunk_size)
        chunk += new_chunk
        processed_words, chunk = _add_bytes_to_result(
            result, counts, chunk, vocab_size, vector_size, datatype, unicode_errors)
        tot_processed_words += processed_words
        if len(new_chunk) < binary_chunk_size:
            break
    if tot_processed_words != vocab_size:
        raise EOFError("unexpected end of input; is count incorrect or file otherwise damaged?")


def _word2vec_read_text(fin, result, counts, vocab_size, vector_size, datatype, unicode_errors, encoding):
    for line_no in range(vocab_size):
        line = fin.readline()
        if line == b'':
            raise EOFError("unexpected end of input; is count incorrect or file otherwise damaged?")
        parts = utils.to_unicode(line.rstrip(), encoding=encoding, errors=unicode_errors).split(" ")
        if len(parts) != vector_size + 1:
            raise ValueError("invalid vector on line %s (is this really the text format?)" % line_no)
        word, weights = parts[0], [datatype(x) for x in parts[1:]]
        _add_word_to_result(result, counts, word, weights, vocab_size)


def _load_word2vec_format(cls, fname, fvocab=None, binary=False, encoding='utf8', unicode_errors='strict',
                          limit=None, datatype=REAL):
                          limit=None, datatype=REAL, binary_chunk_size=100 * 1024):
"""Load the input-hidden weight matrix from the original C word2vec-tool format.
Note that the information stored in the file is incomplete (the binary tree is missing),
@@ -176,14 +251,16 @@ def _load_word2vec_format(cls, fname, fvocab=None, binary=False, encoding='utf8'
    datatype : type, optional
        (Experimental) Can coerce dimensions to a non-default float type (such as `np.float16`) to save memory.
        Such types may result in much slower bulk operations or incompatibility with optimized routines.
    binary_chunk_size : int, optional
        Read input file in chunks of this many bytes for performance reasons.

    Returns
    -------
    object
        Returns the loaded model as an instance of :class:`cls`.

    """
    from gensim.models.keyedvectors import Vocab

    counts = None
    if fvocab is not None:
        logger.info("loading word counts from %s", fvocab)
@@ -203,52 +280,11 @@ def _load_word2vec_format(cls, fname, fvocab=None, binary=False, encoding='utf8'
        result.vector_size = vector_size
        result.vectors = zeros((vocab_size, vector_size), dtype=datatype)

        def add_word(word, weights):
            word_id = len(result.vocab)
            if word in result.vocab:
                logger.warning("duplicate word '%s' in %s, ignoring all but first", word, fname)
                return
            if counts is None:
                # most common scenario: no vocab file given. just make up some bogus counts, in descending order
                result.vocab[word] = Vocab(index=word_id, count=vocab_size - word_id)
            elif word in counts:
                # use count from the vocab file
                result.vocab[word] = Vocab(index=word_id, count=counts[word])
            else:
                # vocab file given, but word is missing -- set count to None (TODO: or raise?)
                logger.warning("vocabulary file is incomplete: '%s' is missing", word)
                result.vocab[word] = Vocab(index=word_id, count=None)
            result.vectors[word_id] = weights
            result.index2word.append(word)

        if binary:
            binary_len = dtype(REAL).itemsize * vector_size
            for _ in range(vocab_size):
                # mixed text and binary: read text first, then binary
                word = []
                while True:
                    ch = fin.read(1)  # Python uses I/O buffering internally
                    if ch == b' ':
                        break
                    if ch == b'':
                        raise EOFError("unexpected end of input; is count incorrect or file otherwise damaged?")
                    if ch != b'\n':  # ignore newlines in front of words (some binary files have)
                        word.append(ch)
                word = utils.to_unicode(b''.join(word), encoding=encoding, errors=unicode_errors)
                with utils.ignore_deprecation_warning():
                    # TODO use frombuffer or something similar
                    weights = fromstring(fin.read(binary_len), dtype=REAL).astype(datatype)
                add_word(word, weights)
            _word2vec_read_binary(fin, result, counts,
                                  vocab_size, vector_size, datatype, unicode_errors, binary_chunk_size)
        else:
            for line_no in range(vocab_size):
                line = fin.readline()
                if line == b'':
                    raise EOFError("unexpected end of input; is count incorrect or file otherwise damaged?")
                parts = utils.to_unicode(line.rstrip(), encoding=encoding, errors=unicode_errors).split(" ")
                if len(parts) != vector_size + 1:
                    raise ValueError("invalid vector on line %s (is this really the text format?)" % line_no)
                word, weights = parts[0], [datatype(x) for x in parts[1:]]
                add_word(word, weights)
            _word2vec_read_text(fin, result, counts, vocab_size, vector_size, datatype, unicode_errors, encoding)
    if result.vectors.shape[0] != len(result.vocab):
        logger.info(
            "duplicate words detected, shrinking matrix size from %i to %i",
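
For intuition, here is a self-contained sketch (not part of the commit) of the parsing step _add_bytes_to_result performs on each chunk: locate the space that terminates the word, decode the word, and view the following bytes as float32 via numpy.frombuffer instead of converting them one read at a time:

import numpy as np

vector_size = 3
bytes_per_vector = vector_size * np.dtype(np.float32).itemsize

# One record of the word2vec binary format: word, a space, then raw float32 bytes.
chunk = b"abc " + np.array([1, 2, 3], dtype=np.float32).tobytes()

i_space = chunk.find(b" ")
word = chunk[:i_space].decode("utf-8").lstrip("\n")  # strip the stray leading newline some files carry
vector = np.frombuffer(chunk, offset=i_space + 1, count=vector_size, dtype=np.float32)

assert word == "abc"
assert vector.tolist() == [1.0, 2.0, 3.0]

A record that straddles a chunk boundary is simply left unparsed: _add_bytes_to_result returns the leftover tail chunk[start:], and _word2vec_read_binary prepends it to the next chunk read from the file.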
122 changes: 122 additions & 0 deletions gensim/test/test_utils_any2vec.py
@@ -0,0 +1,122 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
#
# Copyright (C) 2017 Radim Rehurek <me@radimrehurek.com>
# Licensed under the GNU LGPL v2.1 - http://www.gnu.org/licenses/lgpl.html

"""
Automated tests for checking utils_any2vec functionality.
"""

import logging
import unittest

import numpy as np

import gensim.utils
import gensim.test.utils

import gensim.models.utils_any2vec


logger = logging.getLogger(__name__)


def save_dict_to_word2vec_formated_file(fname, word2vec_dict):

    with gensim.utils.open(fname, "bw") as f:

        num_words = len(word2vec_dict)
        vector_length = len(list(word2vec_dict.values())[0])

        header = "%d %d\n" % (num_words, vector_length)
        f.write(header.encode(encoding="ascii"))

        for word, vector in word2vec_dict.items():
            f.write(word.encode())
            f.write(' '.encode())
            f.write(np.array(vector).astype(np.float32).tobytes())


class LoadWord2VecFormatTest(unittest.TestCase):

    def assert_dict_equal_to_model(self, d, m):
        self.assertEqual(len(d), len(m.vocab))

        for word in d.keys():
            self.assertSequenceEqual(list(d[word]), list(m[word]))

    def verify_load2vec_binary_result(self, w2v_dict, binary_chunk_size, limit):
        tmpfile = gensim.test.utils.get_tmpfile("tmp_w2v")
        save_dict_to_word2vec_formated_file(tmpfile, w2v_dict)
        w2v_model = \
            gensim.models.utils_any2vec._load_word2vec_format(
                cls=gensim.models.KeyedVectors,
                fname=tmpfile,
                binary=True,
                limit=limit,
                binary_chunk_size=binary_chunk_size)
        if limit is None:
            limit = len(w2v_dict)

        w2v_keys_postprocessed = list(w2v_dict.keys())[:limit]
        w2v_dict_postprocessed = {k.lstrip(): w2v_dict[k] for k in w2v_keys_postprocessed}

        self.assert_dict_equal_to_model(w2v_dict_postprocessed, w2v_model)

    def test_load_word2vec_format_basic(self):
        w2v_dict = {"abc": [1, 2, 3],
                    "cde": [4, 5, 6],
                    "def": [7, 8, 9]}
        self.verify_load2vec_binary_result(w2v_dict, binary_chunk_size=5, limit=None)
        self.verify_load2vec_binary_result(w2v_dict, binary_chunk_size=16, limit=None)
        self.verify_load2vec_binary_result(w2v_dict, binary_chunk_size=1024, limit=None)

        w2v_dict = {"abc": [1, 2, 3],
                    "cdefg": [4, 5, 6],
                    "d": [7, 8, 9]}
        self.verify_load2vec_binary_result(w2v_dict, binary_chunk_size=5, limit=None)
        self.verify_load2vec_binary_result(w2v_dict, binary_chunk_size=16, limit=None)
        self.verify_load2vec_binary_result(w2v_dict, binary_chunk_size=1024, limit=None)

    def test_load_word2vec_format_limit(self):
        w2v_dict = {"abc": [1, 2, 3],
                    "cde": [4, 5, 6],
                    "def": [7, 8, 9]}
        self.verify_load2vec_binary_result(w2v_dict, binary_chunk_size=5, limit=1)
        self.verify_load2vec_binary_result(w2v_dict, binary_chunk_size=16, limit=1)
        self.verify_load2vec_binary_result(w2v_dict, binary_chunk_size=1024, limit=1)

        w2v_dict = {"abc": [1, 2, 3],
                    "cde": [4, 5, 6],
                    "def": [7, 8, 9]}
        self.verify_load2vec_binary_result(w2v_dict, binary_chunk_size=5, limit=2)
        self.verify_load2vec_binary_result(w2v_dict, binary_chunk_size=16, limit=2)
        self.verify_load2vec_binary_result(w2v_dict, binary_chunk_size=1024, limit=2)

        w2v_dict = {"abc": [1, 2, 3],
                    "cdefg": [4, 5, 6],
                    "d": [7, 8, 9]}

        self.verify_load2vec_binary_result(w2v_dict, binary_chunk_size=5, limit=1)
        self.verify_load2vec_binary_result(w2v_dict, binary_chunk_size=16, limit=1)
        self.verify_load2vec_binary_result(w2v_dict, binary_chunk_size=1024, limit=1)

        w2v_dict = {"abc": [1, 2, 3],
                    "cdefg": [4, 5, 6],
                    "d": [7, 8, 9]}
        self.verify_load2vec_binary_result(w2v_dict, binary_chunk_size=5, limit=2)
        self.verify_load2vec_binary_result(w2v_dict, binary_chunk_size=16, limit=2)
        self.verify_load2vec_binary_result(w2v_dict, binary_chunk_size=1024, limit=2)

    def test_load_word2vec_format_space_stripping(self):
        w2v_dict = {"\nabc": [1, 2, 3],
                    "cdefdg": [4, 5, 6],
                    "\n\ndef": [7, 8, 9]}
        self.verify_load2vec_binary_result(w2v_dict, binary_chunk_size=5, limit=None)
        self.verify_load2vec_binary_result(w2v_dict, binary_chunk_size=5, limit=1)


if __name__ == '__main__':
    logging.basicConfig(format='%(asctime)s : %(levelname)s : %(message)s', level=logging.DEBUG)
    unittest.main()
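
A round-trip sketch (not part of the commit) tying the two files together, reusing the test helper and temp-file utility above; the "roundtrip_w2v" filename is arbitrary:

import gensim.models
import gensim.test.utils
from gensim.models.utils_any2vec import _load_word2vec_format
from gensim.test.test_utils_any2vec import save_dict_to_word2vec_formated_file

fname = gensim.test.utils.get_tmpfile("roundtrip_w2v")
save_dict_to_word2vec_formated_file(fname, {"abc": [1, 2, 3]})

kv = _load_word2vec_format(cls=gensim.models.KeyedVectors, fname=fname, binary=True)
assert kv["abc"].tolist() == [1.0, 2.0, 3.0]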
