Skip to content

Commit

Permalink
Vectorize word2vec.predict_output_word for speed (#3153)
Browse files Browse the repository at this point in the history
* [Fix] gensim/models/word2vec.py: in method predict_output_word, changed a call to sum to numpy.sum to gain performance.

* [Feat] gensim.models.word2vec.Word2Vec.predict_output_word: added possibility for the user to input a list of word indices as parameter 'context' instead of a list of words.

* Word2Vec.predict_output_word: Changed handling of ints and strs, trying to trying to make it more compact and versatile.

* Fixed docstring of predict_output_word.

* Simplified `predict_output_word` changes.

* Retained the suggested `sum`->`np.sum`
  replacement, which has been tested to
  yield significant runtime gains.
* Dropped unnecessary type/value checks
  that are already run when calling the
  `KeyedVectors.__isin__` dunder method.
* Corrected the docstring to accurately
  document the supported inputs (which
  were already compatible prior to the
  PR this commit is a part of).

* Added tests for gensim.Word2Vec.predict_output_word() when context contains ints.

* Update CHANGELOG.md

* update sbt install step

Co-authored-by: Mathis <mathis.demay@protonmail.com>
Co-authored-by: Paul Andrey <paul.andrey@hotmail.fr>
Co-authored-by: Mathis Demay <mathis.demay.etu@univ-lille.fr>
Co-authored-by: Michael Penkov <m@penkov.dev>
  • Loading branch information
5 people authored Jul 19, 2021
1 parent a93067d commit b287fd8
Show file tree
Hide file tree
Showing 4 changed files with 18 additions and 5 deletions.
3 changes: 2 additions & 1 deletion .github/workflows/tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,8 @@ jobs:
#
- name: Update sbt
run: |
echo "deb https://dl.bintray.com/sbt/debian /" | sudo tee -a /etc/apt/sources.list.d/sbt.list
echo "deb https://repo.scala-sbt.org/scalasbt/debian all main" | sudo tee /etc/apt/sources.list.d/sbt.list
echo "deb https://repo.scala-sbt.org/scalasbt/debian /" | sudo tee /etc/apt/sources.list.d/sbt_old.list
curl -sL "https://keyserver.ubuntu.com/pks/lookup?op=get&search=0x2EE0EA64E40A89B84B2DF73499E82A75642AC823" | sudo apt-key add
sudo apt-get update -y
sudo apt-get install -y sbt
Expand Down
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ Changes
* [#3115](https://github.com/RaRe-Technologies/gensim/pull/3115): Make LSI dispatcher CLI param for number of jobs optional, by [@robguinness](https://github.com/robguinness)
* [#3128](https://github.com/RaRe-Technologies/gensim/pull/3128): Materialize and copy the corpus passed to SoftCosineSimilarity, by [@Witiko](https://github.com/Witiko)
* [#3131](https://github.com/RaRe-Technologies/gensim/pull/3131): Added import to Nmf docs, and to models/__init__.py, by [@properGrammar](https://github.com/properGrammar)
* [#3153](https://github.com/RaRe-Technologies/gensim/pull/3153): Vectorize word2vec.predict_output_word for speed, by [@M-Demay](https://github.com/M-Demay)
* [#3157](https://github.com/RaRe-Technologies/gensim/pull/3157): New KeyedVectors.vectors_for_all method for vectorizing all words in a dictionary, by [@Witiko](https://github.com/Witiko)
* [#3163](https://github.com/RaRe-Technologies/gensim/pull/3163): Optimize word mover distance (WMD) computation, by [@flowlight0](https://github.com/flowlight0)
* [#2965](https://github.com/RaRe-Technologies/gensim/pull/2965): Remove strip_punctuation2 alias of strip_punctuation, by [@sciatro](https://github.com/sciatro)
Expand Down
9 changes: 5 additions & 4 deletions gensim/models/word2vec.py
Original file line number Diff line number Diff line change
Expand Up @@ -1806,8 +1806,9 @@ def predict_output_word(self, context_words_list, topn=10):
Parameters
----------
context_words_list : list of str
List of context words.
context_words_list : list of (str and/or int)
List of context words, which may be words themselves (str)
or their index in `self.wv.vectors` (int).
topn : int, optional
Return `topn` words and their probabilities.
Expand All @@ -1825,8 +1826,8 @@ def predict_output_word(self, context_words_list, topn=10):

if not hasattr(self.wv, 'vectors') or not hasattr(self, 'syn1neg'):
raise RuntimeError("Parameters required for predicting the output words not found.")

word2_indices = [self.wv.get_index(w) for w in context_words_list if w in self.wv]

if not word2_indices:
logger.warning("All the input context words are out-of-vocabulary for the current model.")
return None
Expand All @@ -1837,7 +1838,7 @@ def predict_output_word(self, context_words_list, topn=10):

# propagate hidden -> output and take softmax to get probabilities
prob_values = np.exp(np.dot(l1, self.syn1neg.T))
prob_values /= sum(prob_values)
prob_values /= np.sum(prob_values)
top_indices = matutils.argsort(prob_values, topn=topn, reverse=True)
# returning the most probable output words with their probabilities
return [(self.wv.index_to_key[index1], prob_values[index1]) for index1 in top_indices]
Expand Down
10 changes: 10 additions & 0 deletions gensim/test/test_word2vec.py
Original file line number Diff line number Diff line change
Expand Up @@ -875,6 +875,16 @@ def test_predict_output_word(self):
model_without_neg = word2vec.Word2Vec(sentences, min_count=1, negative=0)
self.assertRaises(RuntimeError, model_without_neg.predict_output_word, ['system', 'human'])

# passing indices instead of words in context
str_context = ['system', 'human']
mixed_context = [model_with_neg.wv.get_index(str_context[0]), str_context[1]]
idx_context = [model_with_neg.wv.get_index(w) for w in str_context]
prediction_from_str = model_with_neg.predict_output_word(str_context, topn=5)
prediction_from_mixed = model_with_neg.predict_output_word(mixed_context, topn=5)
prediction_from_idx = model_with_neg.predict_output_word(idx_context, topn=5)
self.assertEqual(prediction_from_str, prediction_from_mixed)
self.assertEqual(prediction_from_str, prediction_from_idx)

def test_load_old_model(self):
"""Test loading an old word2vec model of indeterminate version"""

Expand Down

0 comments on commit b287fd8

Please sign in to comment.