Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add gensim.models.BaseKeyedVectors.add_entity method for fill KeyedVectors in manual way. Fix #1942 #1957

60 changes: 60 additions & 0 deletions gensim/models/keyedvectors.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,66 @@ def get_vector(self, entity):
else:
raise KeyError("'%s' not in vocabulary" % entity)

def add(self, entities, weights, replace=False):
"""Add entities and theirs vectors in a manual way.
If some entity is already in the vocabulary, old vector is keeped unless `replace` flag is True.

Parameters
----------
entities : list of str
Entities specified by string tags.
weights: {list of numpy.ndarray, numpy.ndarray}
List of 1D np.array vectors or 2D np.array of vectors.
replace: bool, optional
Flag indicating whether to replace vectors for entities which are already in the vocabulary,
if True - replace vectors, otherwise - keep old vectors.
"""
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nitpick: a multiline docstring should end with an empty line, i.e.

"""
...
last text

"""

if isinstance(entities, string_types):
entities = [entities]
weights = weights.reshape(1, -1)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This should probably be weights = np.array(weights).reshape(1, -1), to handle the case where weights is, for example, a list of floats.

elif isinstance(weights, list):
weights = np.array(weights)

in_vocab_mask = np.zeros(len(entities), dtype=np.bool)
for idx, entity in enumerate(entities):
if entity in self.vocab:
in_vocab_mask[idx] = True

# add new entities to the vocab
for idx in np.nonzero(~in_vocab_mask)[0]:
entity = entities[idx]
self.vocab[entity] = Vocab(index=len(self.vocab), count=1)
self.index2entity.append(entity)

# add vectors for new entities
if len(self.vectors) == 0:
self.vectors = weights[~in_vocab_mask]
else:
self.vectors = vstack((self.vectors, weights[~in_vocab_mask]))
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Might this line work even in the case where len(self.vectors)==0, making the check/branch unnecessary?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it's not obvious how to do that, because when an empty KeyedVectors object is created, self.vectors is []. In that case we can't use vstack(([], weights[~in_vocab_mask])): it raises "ValueError: all the input array dimensions except for the concatenation axis must match exactly".

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is it possible for an empty KeyedVectors to have a self.vectors that is already a proper-dimensioned (0, vector_size) empty ndarray? (Not sure myself, but would simplify things in later places like this.)


# change vectors for in_vocab entities if `replace` flag is specified
if replace:
in_vocab_idxs = [self.vocab[entities[idx]].index for idx in np.nonzero(in_vocab_mask)[0]]
self.vectors[in_vocab_idxs] = weights[in_vocab_mask]

def __setitem__(self, entities, weights):
    """Add entities and their vectors in a manual way.

    If some entity is already in the vocabulary, its old vector is replaced with the new one.
    This method is an alias for `add` with `replace=True`.

    Parameters
    ----------
    entities : {str, list of str}
        Entities specified by string tags.
    weights : {list of numpy.ndarray, numpy.ndarray}
        List of 1D np.array vectors or 2D np.array of vectors. When `entities` is a single
        (non-list) entity, any 1D array-like is accepted, e.g. a plain list of floats.

    """
    if not isinstance(entities, list):
        entities = [entities]
        # Convert first: a plain list of floats has no `.reshape` method.
        weights = np.asarray(weights).reshape(1, -1)

    self.add(entities, weights, replace=True)

def __getitem__(self, entities):
"""
Accept a single entity (string tag) or list of entities as input.
Expand Down
72 changes: 72 additions & 0 deletions gensim/test/test_keyedvectors.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,78 @@ def test_wv_property(self):
"""Test that the deprecated `wv` property returns `self`. To be removed in v4.0.0."""
self.assertTrue(self.vectors is self.vectors.wv)

def test_add_single(self):
    """Check that `add` called with one entity at a time stores the given vectors."""
    names = []
    for i in range(5):
        names.append('___some_entity{}_not_present_in_keyed_vectors___'.format(i))
    vecs = [np.random.randn(self.vectors.vector_size) for _ in range(5)]
    pairs = list(zip(names, vecs))

    def insert_then_verify(store):
        # Insert every pair first, and only afterwards read each one back.
        for name, vec in pairs:
            store.add(name, vec)
        for name, vec in pairs:
            self.assertTrue(np.allclose(store[name], vec))

    # Case 1: keyed vectors that already contain entries.
    insert_then_verify(self.vectors)

    # Case 2: freshly constructed, empty keyed vectors.
    insert_then_verify(EuclideanKeyedVectors(self.vectors.vector_size))

def test_add_multiple(self):
    """Check that adding several entities in one call stores all of their vectors."""
    names = []
    for i in range(5):
        names.append('___some_entity{}_not_present_in_keyed_vectors___'.format(i))
    vecs = [np.random.randn(self.vectors.vector_size) for _ in range(5)]

    # Bulk `add` on keyed vectors that already contain entries.
    size_before = len(self.vectors.vocab)
    self.vectors.add(names, vecs, replace=False)
    size_after = len(self.vectors.vocab)
    self.assertEqual(size_before + len(names), size_after)

    for name, vec in zip(names, vecs):
        stored = self.vectors[name]
        self.assertTrue(np.allclose(stored, vec))

    # Bulk item assignment on freshly constructed, empty keyed vectors.
    empty_kv = EuclideanKeyedVectors(self.vectors.vector_size)
    empty_kv[names] = vecs
    self.assertEqual(len(empty_kv.vocab), len(names))

    for name, vec in zip(names, vecs):
        stored = empty_kv[name]
        self.assertTrue(np.allclose(stored, vec))

def test_set_item(self):
    """Check item assignment for a new entity, an existing entity and a mixed list."""
    # 1. Assigning to an unknown entity grows the vocabulary by one.
    size_before = len(self.vectors.vocab)
    new_name = '___some_new_entity___'
    new_vec = np.random.randn(self.vectors.vector_size)
    self.vectors[new_name] = new_vec

    self.assertEqual(len(self.vectors.vocab), size_before + 1)
    self.assertTrue(np.allclose(self.vectors[new_name], new_vec))

    # 2. Assigning to a known entity overwrites its vector; the size stays the same.
    size_before = len(self.vectors.vocab)
    replacement = np.random.randn(self.vectors.vector_size)
    self.vectors['war'] = replacement

    self.assertEqual(len(self.vectors.vocab), size_before)
    self.assertTrue(np.allclose(self.vectors['war'], replacement))

    # 3. Assigning a list that mixes known and unknown entities.
    size_before = len(self.vectors.vocab)
    names = ['war', '___some_new_entity1___', '___some_new_entity2___', 'terrorism', 'conflict']
    vecs = [np.random.randn(self.vectors.vector_size) for _ in range(len(names))]

    self.vectors[names] = vecs

    # Only the two previously unknown names may enlarge the vocabulary.
    self.assertEqual(len(self.vectors.vocab), size_before + 2)
    for name, vec in zip(names, vecs):
        self.assertTrue(np.allclose(self.vectors[name], vec))


if __name__ == '__main__':
logging.basicConfig(format='%(asctime)s : %(levelname)s : %(message)s', level=logging.DEBUG)
Expand Down