Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Speed up word2vec model loading #2671

Merged
merged 6 commits into from
Nov 18, 2019
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
107 changes: 67 additions & 40 deletions gensim/models/utils_any2vec.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
import logging
from gensim import utils

from numpy import zeros, dtype, float32 as REAL, ascontiguousarray, fromstring
from numpy import zeros, dtype, float32 as REAL, ascontiguousarray, frombuffer

from six.moves import range
from six import iteritems, PY2
Expand Down Expand Up @@ -147,7 +147,7 @@ def _save_word2vec_format(fname, vocab, vectors, fvocab=None, binary=False, tota


def _load_word2vec_format(cls, fname, fvocab=None, binary=False, encoding='utf8', unicode_errors='strict',
limit=None, datatype=REAL):
limit=None, datatype=REAL, binary_chunk_size=100 * 1024):
"""Load the input-hidden weight matrix from the original C word2vec-tool format.

Note that the information stored in the file is incomplete (the binary tree is missing),
Expand Down Expand Up @@ -176,14 +176,64 @@ def _load_word2vec_format(cls, fname, fvocab=None, binary=False, encoding='utf8'
datatype : type, optional
(Experimental) Can coerce dimensions to a non-default float type (such as `np.float16`) to save memory.
Such types may result in much slower bulk operations or incompatibility with optimized routines.
binary_chunk_size : int, optional
Size of chunk in which binary files are read. Used mostly for testing. Default value 100 kB.
piskvorky marked this conversation as resolved.
Show resolved Hide resolved

Returns
-------
object
Returns the loaded model as an instance of :class:`cls`.

"""

def __add_word_to_result(result, counts, word, weights):
    """Store `word` and its vector in `result`, keeping only the first occurrence of a word.

    Parameters
    ----------
    result : object
        Model being filled; its `vocab`, `vectors` and `index2word` attributes are updated.
    counts : dict or None
        Word counts read from the vocabulary file, or None when no such file was given.
    word : str
        The word to add.
    weights : array-like
        The vector stored in the row of `result.vectors` assigned to `word`.

    """
    # The new word takes the next free row, i.e. the number of words stored so far.
    word_id = len(result.vocab)
    if word in result.vocab:
        logger.warning("duplicate word '%s' in %s, ignoring all but first", word, fname)
        return

    if counts is None:
        # most common scenario: no vocab file given. just make up some bogus counts, in descending order
        count = vocab_size - word_id
    elif word in counts:
        # use count from the vocab file
        count = counts[word]
    else:
        # vocab file given, but word is missing -- set count to None (TODO: or raise?)
        logger.warning("vocabulary file is incomplete: '%s' is missing", word)
        count = None

    result.vocab[word] = Vocab(index=word_id, count=count)
    result.vectors[word_id] = weights
    result.index2word.append(word)

def __remove_initial_new_line(s):
    """Return `s` with every leading newline character removed.

    Only newlines at the very start are dropped; newlines inside or at the end of `s` are kept.

    Parameters
    ----------
    s : str
        The text to clean up.

    Returns
    -------
    str
        `s` without its leading newline characters.

    """
    # str.lstrip with an explicit character set does what the manual index loop did.
    return s.lstrip('\n')

def __add_words_from_binary_chunk_to_result(result, counts, max_words, chunk, vector_size, datatype):
    """Parse complete "word vector" records from `chunk` and add them to `result`.

    A record is the encoded word, a single space, and `vector_size` values of type `REAL`.
    Parsing stops after `max_words` records, or at the first record that is not yet fully
    contained in `chunk`.

    Words are decoded with the `encoding` and `unicode_errors` arguments of the enclosing
    loader function.

    Parameters
    ----------
    result : object
        Model being filled; passed on to `__add_word_to_result`.
    counts : dict or None
        Word counts; passed on to `__add_word_to_result`.
    max_words : int
        Maximum number of records to parse from this chunk.
    chunk : bytes
        Raw bytes read from the file; parsing starts at offset 0.
    vector_size : int
        Number of values in each vector.
    datatype : type
        Type each parsed vector is converted to.

    Returns
    -------
    (int, bytes)
        The number of records parsed, and the unparsed remainder of `chunk`.

    """
    start = 0
    n = len(chunk)
    processed_words = 0
    n_bytes_per_vector = vector_size * dtype(REAL).itemsize

    for _ in range(max_words):
        i_space = chunk.find(b' ', start)
        i_vector = i_space + 1
        if i_space == -1 or (n - i_vector) < n_bytes_per_vector:
            # The next record is not complete in this chunk; leave it in the returned remainder.
            break

        # Decode with the caller-supplied encoding rather than a hard-coded UTF-8,
        # so that the `encoding` argument is honoured for binary files too.
        word = chunk[start:i_space].decode(encoding, errors=unicode_errors)
        # Some binary files are reported to have obsolete new line in the beginning of word, remove it
        word = __remove_initial_new_line(word)
        # frombuffer gives a view onto `chunk`; astype makes an independent array of the requested type.
        vector = frombuffer(chunk, offset=i_vector, count=vector_size, dtype=REAL).astype(datatype)
        __add_word_to_result(result, counts, word, vector)
        start = i_vector + n_bytes_per_vector
        processed_words += 1

    return processed_words, chunk[start:]

from gensim.models.keyedvectors import Vocab

counts = None
if fvocab is not None:
logger.info("loading word counts from %s", fvocab)
Expand All @@ -192,7 +242,6 @@ def _load_word2vec_format(cls, fname, fvocab=None, binary=False, encoding='utf8'
for line in fin:
word, count = utils.to_unicode(line, errors=unicode_errors).strip().split()
counts[word] = int(count)

piskvorky marked this conversation as resolved.
Show resolved Hide resolved
logger.info("loading projection weights from %s", fname)
with utils.open(fname, 'rb') as fin:
header = utils.to_unicode(fin.readline(), encoding=encoding)
Expand All @@ -202,43 +251,21 @@ def _load_word2vec_format(cls, fname, fvocab=None, binary=False, encoding='utf8'
result = cls(vector_size)
result.vector_size = vector_size
result.vectors = zeros((vocab_size, vector_size), dtype=datatype)

def add_word(word, weights):
word_id = len(result.vocab)
if word in result.vocab:
logger.warning("duplicate word '%s' in %s, ignoring all but first", word, fname)
return
if counts is None:
# most common scenario: no vocab file given. just make up some bogus counts, in descending order
result.vocab[word] = Vocab(index=word_id, count=vocab_size - word_id)
elif word in counts:
# use count from the vocab file
result.vocab[word] = Vocab(index=word_id, count=counts[word])
else:
# vocab file given, but word is missing -- set count to None (TODO: or raise?)
logger.warning("vocabulary file is incomplete: '%s' is missing", word)
result.vocab[word] = Vocab(index=word_id, count=None)
result.vectors[word_id] = weights
result.index2word.append(word)

if binary:
mpenkov marked this conversation as resolved.
Show resolved Hide resolved
binary_len = dtype(REAL).itemsize * vector_size
for _ in range(vocab_size):
# mixed text and binary: read text first, then binary
word = []
while True:
ch = fin.read(1) # Python uses I/O buffering internally
if ch == b' ':
break
if ch == b'':
raise EOFError("unexpected end of input; is count incorrect or file otherwise damaged?")
if ch != b'\n': # ignore newlines in front of words (some binary files have)
word.append(ch)
word = utils.to_unicode(b''.join(word), encoding=encoding, errors=unicode_errors)
with utils.ignore_deprecation_warning():
# TODO use frombuffer or something similar
weights = fromstring(fin.read(binary_len), dtype=REAL).astype(datatype)
add_word(word, weights)
chunk = b''
tot_processed_words = 0

while tot_processed_words < vocab_size:
new_chunk = fin.read(binary_chunk_size)
chunk += new_chunk
max_words = vocab_size - len(result.vocab)
processed_words, chunk = __add_words_from_binary_chunk_to_result(result, counts, max_words,
chunk, vector_size, datatype)
piskvorky marked this conversation as resolved.
Show resolved Hide resolved
tot_processed_words += processed_words
if len(new_chunk) < binary_chunk_size:
break
if tot_processed_words != vocab_size:
raise EOFError("unexpected end of input; is count incorrect or file otherwise damaged?")
piskvorky marked this conversation as resolved.
Show resolved Hide resolved
else:
for line_no in range(vocab_size):
line = fin.readline()
Expand All @@ -248,7 +275,7 @@ def add_word(word, weights):
if len(parts) != vector_size + 1:
raise ValueError("invalid vector on line %s (is this really the text format?)" % line_no)
word, weights = parts[0], [datatype(x) for x in parts[1:]]
add_word(word, weights)
__add_word_to_result(result, counts, word, weights)
if result.vectors.shape[0] != len(result.vocab):
logger.info(
"duplicate words detected, shrinking matrix size from %i to %i",
Expand Down
122 changes: 122 additions & 0 deletions gensim/test/test_utils_any2vec.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,122 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
#
# Copyright (C) 2017 Radim Rehurek <me@radimrehurek.com>
piskvorky marked this conversation as resolved.
Show resolved Hide resolved
# Licensed under the GNU LGPL v2.1 - http://www.gnu.org/licenses/lgpl.html

"""
Automated tests for checking utils_any2vec functionality.
"""

import logging
import unittest

import numpy as np

import gensim.utils
import gensim.test.utils

import gensim.models.utils_any2vec


logger = logging.getLogger(__name__)


def save_dict_to_word2vec_formated_file(fname, word2vec_dict):
    """Write `word2vec_dict` to `fname` in the binary word2vec layout.

    The file starts with an ASCII header line "<number of words> <vector length>",
    followed by one record per dictionary entry: the encoded word, a single space,
    and the vector as raw float32 bytes.

    Parameters
    ----------
    fname : str
        Path of the file to write.
    word2vec_dict : dict
        Maps each word to its vector (a sequence of numbers). The vector length
        written to the header is taken from the first entry.

    """
    with gensim.utils.open(fname, "bw") as f:
        first_vector = list(word2vec_dict.values())[0]
        header = "%d %d\n" % (len(word2vec_dict), len(first_vector))
        f.write(header.encode(encoding="ascii"))

        for word in word2vec_dict:
            raw_vector = np.array(word2vec_dict[word]).astype(np.float32).tobytes()
            f.write(word.encode())
            f.write(' '.encode())
            f.write(raw_vector)


class LoadWord2VecFormatTest(unittest.TestCase):
    """Tests for loading binary word2vec files with `_load_word2vec_format` in chunks."""

    def assert_dict_equal_to_model(self, d, m):
        """Assert that `m` has as many words as `d` and the same vector values for each word of `d`."""
        self.assertEqual(len(d), len(m.vocab))

        for word, expected_vector in d.items():
            self.assertSequenceEqual(list(expected_vector), list(m[word]))

    def verify_load2vec_binary_result(self, w2v_dict, binary_chunk_size, limit):
        """Save `w2v_dict` to a temporary binary file, load it back and compare with the expectation.

        The expected content is the first `limit` entries of `w2v_dict` (all of them when
        `limit` is None), with leading whitespace stripped from each word.
        """
        tmpfile = gensim.test.utils.get_tmpfile("tmp_w2v")
        save_dict_to_word2vec_formated_file(tmpfile, w2v_dict)
        w2v_model = gensim.models.utils_any2vec._load_word2vec_format(
            cls=gensim.models.KeyedVectors,
            fname=tmpfile,
            binary=True,
            limit=limit,
            binary_chunk_size=binary_chunk_size)

        n_expected = len(w2v_dict) if limit is None else limit
        expected = {}
        for word in list(w2v_dict)[:n_expected]:
            expected[word.lstrip()] = w2v_dict[word]

        self.assert_dict_equal_to_model(expected, w2v_model)

    def _verify_for_chunk_sizes(self, w2v_dict, limit):
        """Run `verify_load2vec_binary_result` with chunk sizes 5, 16 and 1024, in that order."""
        for binary_chunk_size in (5, 16, 1024):
            self.verify_load2vec_binary_result(w2v_dict, binary_chunk_size=binary_chunk_size, limit=limit)

    def test_load_word2vec_format_basic(self):
        same_length_words = {"abc": [1, 2, 3],
                             "cde": [4, 5, 6],
                             "def": [7, 8, 9]}
        self._verify_for_chunk_sizes(same_length_words, limit=None)

        mixed_length_words = {"abc": [1, 2, 3],
                              "cdefg": [4, 5, 6],
                              "d": [7, 8, 9]}
        self._verify_for_chunk_sizes(mixed_length_words, limit=None)

    def test_load_word2vec_format_limit(self):
        same_length_words = {"abc": [1, 2, 3],
                             "cde": [4, 5, 6],
                             "def": [7, 8, 9]}
        mixed_length_words = {"abc": [1, 2, 3],
                              "cdefg": [4, 5, 6],
                              "d": [7, 8, 9]}

        for w2v_dict in (same_length_words, mixed_length_words):
            for limit in (1, 2):
                self._verify_for_chunk_sizes(w2v_dict, limit=limit)

    def test_load_word2vec_format_space_stripping(self):
        w2v_dict = {"\nabc": [1, 2, 3],
                    "cdefdg": [4, 5, 6],
                    "\n\ndef": [7, 8, 9]}
        for limit in (None, 1):
            self.verify_load2vec_binary_result(w2v_dict, binary_chunk_size=5, limit=limit)


if __name__ == '__main__':
    # When run as a script: log at DEBUG level, then run the tests of this module.
    logging.basicConfig(level=logging.DEBUG, format='%(asctime)s : %(levelname)s : %(message)s')
    unittest.main()