diff --git a/gensim/models/utils_any2vec.py b/gensim/models/utils_any2vec.py index 4f5396c853..563f26b8f5 100644 --- a/gensim/models/utils_any2vec.py +++ b/gensim/models/utils_any2vec.py @@ -2,7 +2,7 @@ # -*- coding: utf-8 -*- # # Author: Shiva Manne -# Copyright (C) 2018 RaRe Technologies s.r.o. +# Copyright (C) 2019 RaRe Technologies s.r.o. """General functions used for any2vec models. @@ -28,7 +28,7 @@ import logging from gensim import utils -from numpy import zeros, dtype, float32 as REAL, ascontiguousarray, fromstring +from numpy import zeros, dtype, float32 as REAL, ascontiguousarray, frombuffer from six.moves import range from six import iteritems, PY2 @@ -146,8 +146,83 @@ def _save_word2vec_format(fname, vocab, vectors, fvocab=None, binary=False, tota fout.write(utils.to_utf8("%s %s\n" % (word, ' '.join(repr(val) for val in row)))) +# Functions for internal use by _load_word2vec_format function + + +def _add_word_to_result(result, counts, word, weights, vocab_size): + from gensim.models.keyedvectors import Vocab + word_id = len(result.vocab) + if word in result.vocab: + logger.warning("duplicate word '%s' in word2vec file, ignoring all but first", word) + return + if counts is None: + # most common scenario: no vocab file given. just make up some bogus counts, in descending order + word_count = vocab_size - word_id + elif word in counts: + # use count from the vocab file + word_count = counts[word] + else: + logger.warning("vocabulary file is incomplete: '%s' is missing", word) + word_count = None + + result.vocab[word] = Vocab(index=word_id, count=word_count) + result.vectors[word_id] = weights + result.index2word.append(word) + + +def _add_bytes_to_result(result, counts, chunk, vocab_size, vector_size, datatype, unicode_errors): + start = 0 + processed_words = 0 + bytes_per_vector = vector_size * dtype(REAL).itemsize + max_words = vocab_size - len(result.vocab) + for _ in range(max_words): + i_space = chunk.find(b' ', start) + i_vector = i_space + 1 + + if i_space == -1 or (len(chunk) - i_vector) < bytes_per_vector: + break + + word = chunk[start:i_space].decode("utf-8", errors=unicode_errors) + # Some binary files are reported to have obsolete new line in the beginning of word, remove it + word = word.lstrip('\n') + vector = frombuffer(chunk, offset=i_vector, count=vector_size, dtype=REAL).astype(datatype) + _add_word_to_result(result, counts, word, vector, vocab_size) + start = i_vector + bytes_per_vector + processed_words += 1 + + return processed_words, chunk[start:] + + +def _word2vec_read_binary(fin, result, counts, vocab_size, vector_size, datatype, unicode_errors, binary_chunk_size): + chunk = b'' + tot_processed_words = 0 + + while tot_processed_words < vocab_size: + new_chunk = fin.read(binary_chunk_size) + chunk += new_chunk + processed_words, chunk = _add_bytes_to_result( + result, counts, chunk, vocab_size, vector_size, datatype, unicode_errors) + tot_processed_words += processed_words + if len(new_chunk) < binary_chunk_size: + break + if tot_processed_words != vocab_size: + raise EOFError("unexpected end of input; is count incorrect or file otherwise damaged?") + + +def _word2vec_read_text(fin, result, counts, vocab_size, vector_size, datatype, unicode_errors, encoding): + for line_no in range(vocab_size): + line = fin.readline() + if line == b'': + raise EOFError("unexpected end of input; is count incorrect or file otherwise damaged?") + parts = utils.to_unicode(line.rstrip(), encoding=encoding, errors=unicode_errors).split(" ") + if len(parts) != vector_size + 1: + raise ValueError("invalid vector on line %s (is this really the text format?)" % line_no) + word, weights = parts[0], [datatype(x) for x in parts[1:]] + _add_word_to_result(result, counts, word, weights, vocab_size) + + def _load_word2vec_format(cls, fname, fvocab=None, binary=False, encoding='utf8', unicode_errors='strict', - limit=None, datatype=REAL): + limit=None, datatype=REAL, binary_chunk_size=100 * 1024): """Load the input-hidden weight matrix from the original C word2vec-tool format. Note that the information stored in the file is incomplete (the binary tree is missing), @@ -176,6 +251,8 @@ def _load_word2vec_format(cls, fname, fvocab=None, binary=False, encoding='utf8' datatype : type, optional (Experimental) Can coerce dimensions to a non-default float type (such as `np.float16`) to save memory. Such types may result in much slower bulk operations or incompatibility with optimized routines.) + binary_chunk_size : int, optional + Read input file in chunks of this many bytes for performance reasons. Returns ------- @@ -183,7 +260,7 @@ def _load_word2vec_format(cls, fname, fvocab=None, binary=False, encoding='utf8' Returns the loaded model as an instance of :class:`cls`. """ - from gensim.models.keyedvectors import Vocab + counts = None if fvocab is not None: logger.info("loading word counts from %s", fvocab) @@ -203,52 +280,11 @@ def _load_word2vec_format(cls, fname, fvocab=None, binary=False, encoding='utf8' result.vector_size = vector_size result.vectors = zeros((vocab_size, vector_size), dtype=datatype) - def add_word(word, weights): - word_id = len(result.vocab) - if word in result.vocab: - logger.warning("duplicate word '%s' in %s, ignoring all but first", word, fname) - return - if counts is None: - # most common scenario: no vocab file given. just make up some bogus counts, in descending order - result.vocab[word] = Vocab(index=word_id, count=vocab_size - word_id) - elif word in counts: - # use count from the vocab file - result.vocab[word] = Vocab(index=word_id, count=counts[word]) - else: - # vocab file given, but word is missing -- set count to None (TODO: or raise?) - logger.warning("vocabulary file is incomplete: '%s' is missing", word) - result.vocab[word] = Vocab(index=word_id, count=None) - result.vectors[word_id] = weights - result.index2word.append(word) - if binary: - binary_len = dtype(REAL).itemsize * vector_size - for _ in range(vocab_size): - # mixed text and binary: read text first, then binary - word = [] - while True: - ch = fin.read(1) # Python uses I/O buffering internally - if ch == b' ': - break - if ch == b'': - raise EOFError("unexpected end of input; is count incorrect or file otherwise damaged?") - if ch != b'\n': # ignore newlines in front of words (some binary files have) - word.append(ch) - word = utils.to_unicode(b''.join(word), encoding=encoding, errors=unicode_errors) - with utils.ignore_deprecation_warning(): - # TODO use frombuffer or something similar - weights = fromstring(fin.read(binary_len), dtype=REAL).astype(datatype) - add_word(word, weights) + _word2vec_read_binary(fin, result, counts, + vocab_size, vector_size, datatype, unicode_errors, binary_chunk_size) else: - for line_no in range(vocab_size): - line = fin.readline() - if line == b'': - raise EOFError("unexpected end of input; is count incorrect or file otherwise damaged?") - parts = utils.to_unicode(line.rstrip(), encoding=encoding, errors=unicode_errors).split(" ") - if len(parts) != vector_size + 1: - raise ValueError("invalid vector on line %s (is this really the text format?)" % line_no) - word, weights = parts[0], [datatype(x) for x in parts[1:]] - add_word(word, weights) + _word2vec_read_text(fin, result, counts, vocab_size, vector_size, datatype, unicode_errors, encoding) if result.vectors.shape[0] != len(result.vocab): logger.info( "duplicate words detected, shrinking matrix size from %i to %i", diff --git a/gensim/test/test_utils_any2vec.py b/gensim/test/test_utils_any2vec.py new file mode 100644 index 0000000000..f4c5c2c430 --- /dev/null +++ b/gensim/test/test_utils_any2vec.py @@ -0,0 +1,122 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# +# Copyright (C) 2017 Radim Rehurek +# Licensed under the GNU LGPL v2.1 - http://www.gnu.org/licenses/lgpl.html + +""" +Automated tests for checking utils_any2vec functionality. +""" + +import logging +import unittest + +import numpy as np + +import gensim.utils +import gensim.test.utils + +import gensim.models.utils_any2vec + + +logger = logging.getLogger(__name__) + + +def save_dict_to_word2vec_formated_file(fname, word2vec_dict): + + with gensim.utils.open(fname, "bw") as f: + + num_words = len(word2vec_dict) + vector_length = len(list(word2vec_dict.values())[0]) + + header = "%d %d\n" % (num_words, vector_length) + f.write(header.encode(encoding="ascii")) + + for word, vector in word2vec_dict.items(): + f.write(word.encode()) + f.write(' '.encode()) + f.write(np.array(vector).astype(np.float32).tobytes()) + + +class LoadWord2VecFormatTest(unittest.TestCase): + + def assert_dict_equal_to_model(self, d, m): + self.assertEqual(len(d), len(m.vocab)) + + for word in d.keys(): + self.assertSequenceEqual(list(d[word]), list(m[word])) + + def verify_load2vec_binary_result(self, w2v_dict, binary_chunk_size, limit): + tmpfile = gensim.test.utils.get_tmpfile("tmp_w2v") + save_dict_to_word2vec_formated_file(tmpfile, w2v_dict) + w2v_model = \ + gensim.models.utils_any2vec._load_word2vec_format( + cls=gensim.models.KeyedVectors, + fname=tmpfile, + binary=True, + limit=limit, + binary_chunk_size=binary_chunk_size) + if limit is None: + limit = len(w2v_dict) + + w2v_keys_postprocessed = list(w2v_dict.keys())[:limit] + w2v_dict_postprocessed = {k.lstrip(): w2v_dict[k] for k in w2v_keys_postprocessed} + + self.assert_dict_equal_to_model(w2v_dict_postprocessed, w2v_model) + + def test_load_word2vec_format_basic(self): + w2v_dict = {"abc": [1, 2, 3], + "cde": [4, 5, 6], + "def": [7, 8, 9]} + self.verify_load2vec_binary_result(w2v_dict, binary_chunk_size=5, limit=None) + self.verify_load2vec_binary_result(w2v_dict, binary_chunk_size=16, limit=None) + self.verify_load2vec_binary_result(w2v_dict, binary_chunk_size=1024, limit=None) + + w2v_dict = {"abc": [1, 2, 3], + "cdefg": [4, 5, 6], + "d": [7, 8, 9]} + self.verify_load2vec_binary_result(w2v_dict, binary_chunk_size=5, limit=None) + self.verify_load2vec_binary_result(w2v_dict, binary_chunk_size=16, limit=None) + self.verify_load2vec_binary_result(w2v_dict, binary_chunk_size=1024, limit=None) + + def test_load_word2vec_format_limit(self): + w2v_dict = {"abc": [1, 2, 3], + "cde": [4, 5, 6], + "def": [7, 8, 9]} + self.verify_load2vec_binary_result(w2v_dict, binary_chunk_size=5, limit=1) + self.verify_load2vec_binary_result(w2v_dict, binary_chunk_size=16, limit=1) + self.verify_load2vec_binary_result(w2v_dict, binary_chunk_size=1024, limit=1) + + w2v_dict = {"abc": [1, 2, 3], + "cde": [4, 5, 6], + "def": [7, 8, 9]} + self.verify_load2vec_binary_result(w2v_dict, binary_chunk_size=5, limit=2) + self.verify_load2vec_binary_result(w2v_dict, binary_chunk_size=16, limit=2) + self.verify_load2vec_binary_result(w2v_dict, binary_chunk_size=1024, limit=2) + + w2v_dict = {"abc": [1, 2, 3], + "cdefg": [4, 5, 6], + "d": [7, 8, 9]} + + self.verify_load2vec_binary_result(w2v_dict, binary_chunk_size=5, limit=1) + self.verify_load2vec_binary_result(w2v_dict, binary_chunk_size=16, limit=1) + self.verify_load2vec_binary_result(w2v_dict, binary_chunk_size=1024, limit=1) + + w2v_dict = {"abc": [1, 2, 3], + "cdefg": [4, 5, 6], + "d": [7, 8, 9]} + self.verify_load2vec_binary_result(w2v_dict, binary_chunk_size=5, limit=2) + self.verify_load2vec_binary_result(w2v_dict, binary_chunk_size=16, limit=2) + self.verify_load2vec_binary_result(w2v_dict, binary_chunk_size=1024, limit=2) + + def test_load_word2vec_format_space_stripping(self): + w2v_dict = {"\nabc": [1, 2, 3], + "cdefdg": [4, 5, 6], + "\n\ndef": [7, 8, 9]} + self.verify_load2vec_binary_result(w2v_dict, binary_chunk_size=5, limit=None) + self.verify_load2vec_binary_result(w2v_dict, binary_chunk_size=5, limit=1) + + +if __name__ == '__main__': + logging.basicConfig(format='%(asctime)s : %(levelname)s : %(message)s', level=logging.DEBUG) + unittest.main()