Skip to content

Commit

Permalink
parametrize datatype & cast embeddings passed to add to KV datatype
Browse files Browse the repository at this point in the history
  • Loading branch information
menshikh-iv committed Mar 1, 2020
1 parent 38daff5 commit e031fc4
Show file tree
Hide file tree
Showing 3 changed files with 14 additions and 14 deletions.
15 changes: 7 additions & 8 deletions gensim/models/keyedvectors.py
Original file line number Diff line number Diff line change
Expand Up @@ -214,8 +214,8 @@ def __str__(self):

class BaseKeyedVectors(utils.SaveLoad):
"""Abstract base class / interface for various types of word vectors."""
def __init__(self, vector_size):
self.vectors = zeros((0, vector_size))
def __init__(self, vector_size, dtype=REAL):
self.vectors = zeros((0, vector_size), dtype=dtype)
self.vocab = {}
self.vector_size = vector_size
self.index2entity = []
Expand Down Expand Up @@ -308,8 +308,7 @@ def add(self, entities, weights, replace=False):
self.index2entity.append(entity)

# add vectors for new entities
self.vectors = self.vectors.astype(weights.dtype) # cast existing vectors to 'weights' type
self.vectors = vstack((self.vectors, weights[~in_vocab_mask]))
self.vectors = vstack((self.vectors, weights[~in_vocab_mask].astype(self.vectors.dtype)))

# change vectors for in_vocab entities if `replace` flag is specified
if replace:
Expand Down Expand Up @@ -377,8 +376,8 @@ def rank(self, entity1, entity2):

class WordEmbeddingsKeyedVectors(BaseKeyedVectors):
"""Class containing common methods for operations over word vectors."""
def __init__(self, vector_size):
super(WordEmbeddingsKeyedVectors, self).__init__(vector_size=vector_size)
def __init__(self, vector_size, dtype=REAL):
super(WordEmbeddingsKeyedVectors, self).__init__(vector_size=vector_size, dtype=REAL)
self.vectors_norm = None
self.index2word = []

Expand Down Expand Up @@ -1551,8 +1550,8 @@ def load(cls, fname_or_handle, **kwargs):

class Doc2VecKeyedVectors(BaseKeyedVectors):

def __init__(self, vector_size, mapfile_path):
super(Doc2VecKeyedVectors, self).__init__(vector_size=vector_size)
def __init__(self, vector_size, mapfile_path, dtype=REAL):
super(Doc2VecKeyedVectors, self).__init__(vector_size=vector_size, dtype=REAL)
self.doctags = {} # string -> Doctag (only filled if necessary)
self.max_rawint = -1 # highest rawint-indexed doctag
self.offset2doctag = [] # int offset-past-(max_rawint+1) -> String (only filled if necessary)
Expand Down
4 changes: 2 additions & 2 deletions gensim/models/poincare.py
Original file line number Diff line number Diff line change
Expand Up @@ -866,8 +866,8 @@ class PoincareKeyedVectors(BaseKeyedVectors):
Used to perform operations on the vectors such as vector lookup, distance calculations etc.
"""
def __init__(self, vector_size):
super(PoincareKeyedVectors, self).__init__(vector_size)
def __init__(self, vector_size, dtype=REAL):
super(PoincareKeyedVectors, self).__init__(vector_size, dtype=REAL)
self.max_distance = 0
self.index2word = []
self.vocab = {}
Expand Down
9 changes: 5 additions & 4 deletions gensim/test/test_keyedvectors.py
Original file line number Diff line number Diff line change
Expand Up @@ -256,13 +256,14 @@ def test_add_multiple(self):
self.assertTrue(np.allclose(kv[ent], vector))

def test_add_type(self):
kv = KeyedVectors(2)
words, vectors = ["a"], np.array([1., 1.], dtype=np.float32).reshape(1, -1)
dtype = np.float32
kv = KeyedVectors(2, dtype=dtype)
words, vectors = ["a"], np.array([1., 1.], dtype=np.float16).reshape(1, -1)

assert kv.vectors.dtype == np.float64 # default dtype of empty KV
assert kv.vectors.dtype == dtype

kv.add(words, vectors)
assert kv.vectors.dtype == np.float32 # new dtype of KV (copied from passed vectors)
assert kv.vectors.dtype == dtype

def test_set_item(self):
"""Test that __setitem__ works correctly."""
Expand Down

0 comments on commit e031fc4

Please sign in to comment.