From e031fc4dccc87a4f922087ed6271f61fcb1bd5a3 Mon Sep 17 00:00:00 2001 From: Ivan Menshikh Date: Sun, 1 Mar 2020 11:19:23 +0300 Subject: [PATCH] parametrize datatype & cast embeddings passed to `add` to KV datatype --- gensim/models/keyedvectors.py | 15 +++++++-------- gensim/models/poincare.py | 4 ++-- gensim/test/test_keyedvectors.py | 9 +++++---- 3 files changed, 14 insertions(+), 14 deletions(-) diff --git a/gensim/models/keyedvectors.py b/gensim/models/keyedvectors.py index 13c9373744..d22df999f4 100644 --- a/gensim/models/keyedvectors.py +++ b/gensim/models/keyedvectors.py @@ -214,8 +214,8 @@ def __str__(self): class BaseKeyedVectors(utils.SaveLoad): """Abstract base class / interface for various types of word vectors.""" - def __init__(self, vector_size): - self.vectors = zeros((0, vector_size)) + def __init__(self, vector_size, dtype=REAL): + self.vectors = zeros((0, vector_size), dtype=dtype) self.vocab = {} self.vector_size = vector_size self.index2entity = [] @@ -308,8 +308,7 @@ def add(self, entities, weights, replace=False): self.index2entity.append(entity) # add vectors for new entities - self.vectors = self.vectors.astype(weights.dtype) # cast existing vectors to 'weights' type - self.vectors = vstack((self.vectors, weights[~in_vocab_mask])) + self.vectors = vstack((self.vectors, weights[~in_vocab_mask].astype(self.vectors.dtype))) # change vectors for in_vocab entities if `replace` flag is specified if replace: @@ -377,8 +376,8 @@ def rank(self, entity1, entity2): class WordEmbeddingsKeyedVectors(BaseKeyedVectors): """Class containing common methods for operations over word vectors.""" - def __init__(self, vector_size): - super(WordEmbeddingsKeyedVectors, self).__init__(vector_size=vector_size) + def __init__(self, vector_size, dtype=REAL): + super(WordEmbeddingsKeyedVectors, self).__init__(vector_size=vector_size, dtype=dtype) self.vectors_norm = None self.index2word = [] @@ -1551,8 +1550,8 @@ def load(cls, fname_or_handle, **kwargs): class
Doc2VecKeyedVectors(BaseKeyedVectors): - def __init__(self, vector_size, mapfile_path): - super(Doc2VecKeyedVectors, self).__init__(vector_size=vector_size) + def __init__(self, vector_size, mapfile_path, dtype=REAL): + super(Doc2VecKeyedVectors, self).__init__(vector_size=vector_size, dtype=dtype) self.doctags = {} # string -> Doctag (only filled if necessary) self.max_rawint = -1 # highest rawint-indexed doctag self.offset2doctag = [] # int offset-past-(max_rawint+1) -> String (only filled if necessary) diff --git a/gensim/models/poincare.py b/gensim/models/poincare.py index 295125d666..715d757a26 100644 --- a/gensim/models/poincare.py +++ b/gensim/models/poincare.py @@ -866,8 +866,8 @@ class PoincareKeyedVectors(BaseKeyedVectors): Used to perform operations on the vectors such as vector lookup, distance calculations etc. """ - def __init__(self, vector_size): - super(PoincareKeyedVectors, self).__init__(vector_size) + def __init__(self, vector_size, dtype=REAL): + super(PoincareKeyedVectors, self).__init__(vector_size, dtype=dtype) self.max_distance = 0 self.index2word = [] self.vocab = {} diff --git a/gensim/test/test_keyedvectors.py b/gensim/test/test_keyedvectors.py index 77e2ac07d3..3f0b0ae4f8 100644 --- a/gensim/test/test_keyedvectors.py +++ b/gensim/test/test_keyedvectors.py @@ -256,13 +256,14 @@ def test_add_multiple(self): self.assertTrue(np.allclose(kv[ent], vector)) def test_add_type(self): - kv = KeyedVectors(2) - words, vectors = ["a"], np.array([1., 1.], dtype=np.float32).reshape(1, -1) + dtype = np.float64 + kv = KeyedVectors(2, dtype=dtype) + words, vectors = ["a"], np.array([1., 1.], dtype=np.float16).reshape(1, -1) - assert kv.vectors.dtype == np.float64 # default dtype of empty KV + assert kv.vectors.dtype == dtype kv.add(words, vectors) - assert kv.vectors.dtype == np.float32 # new dtype of KV (copied from passed vectors) + assert kv.vectors.dtype == dtype def test_set_item(self): """Test that __setitem__ works correctly."""