Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fixed KeyError in coherence model #2830

Merged
merged 13 commits
Jun 29, 2021
16 changes: 10 additions & 6 deletions gensim/models/coherencemodel.py
Original file line number Diff line number Diff line change
Expand Up @@ -441,12 +441,16 @@ def topics(self, topics):
self._topics = new_topics

def _ensure_elements_are_ids(self, topic):
try:
return np.array([self.dictionary.token2id[token] for token in topic])
except KeyError: # might be a list of token ids already, but let's verify all in dict
topic = (self.dictionary.id2token[_id] for _id in topic)
return np.array([self.dictionary.token2id[token] for token in topic])

tokens = [t for t in topic if t in self.dictionary.token2id]
pietrotrope marked this conversation as resolved.
Show resolved Hide resolved
elements_are_tokens = np.array([self.dictionary.token2id[token] for token in tokens])
elements_are_ids = np.array([i for i in topic if i in self.dictionary.id2token])
if elements_are_tokens.size > elements_are_ids.size:
return elements_are_tokens
elif elements_are_ids.size > elements_are_tokens.size:
return elements_are_ids
else:
raise Exception("Topic list is not a list of lists of tokens or ids")
pietrotrope marked this conversation as resolved.
Show resolved Hide resolved

def _update_accumulator(self, new_topics):
if self._relevant_ids_will_differ(new_topics):
logger.debug("Wiping cached accumulator since it does not contain all relevant ids.")
Expand Down
7 changes: 7 additions & 0 deletions gensim/test/test_coherencemodel.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,11 @@ def setUp(self):
['user', 'graph', 'minors', 'system'],
['time', 'graph', 'survey', 'minors']
]
self.topics3 = [
['human', 'computer', 'system', 'interface'],
['graph', 'minors', 'trees', 'eps']
]
pietrotrope marked this conversation as resolved.
Show resolved Hide resolved

self.ldamodel = LdaModel(
corpus=self.corpus, id2word=self.dictionary, num_topics=2,
passes=0, iterations=0
Expand Down Expand Up @@ -79,6 +84,8 @@ def check_coherence_measure(self, coherence):

cm1 = CoherenceModel(topics=self.topics1, **kwargs)
cm2 = CoherenceModel(topics=self.topics2, **kwargs)
cm3 = CoherenceModel(topics=self.topics3, **kwargs)
self.assertIsInstance(cm3.get_coherence(), np.double)
self.assertGreater(cm1.get_coherence(), cm2.get_coherence())

def testUMass(self):
Expand Down