Skip to content

Commit

Permalink
Merge pull request #1327 from chinmayapancholi13/random_state_pr4
Browse files Browse the repository at this point in the history
[WIP] Fix backward incompatibility due to `random_state`
  • Loading branch information
menshikh-iv authored May 25, 2017
2 parents 7b6afc0 + c7194c9 commit 7414b60
Show file tree
Hide file tree
Showing 4 changed files with 35 additions and 9 deletions.
9 changes: 5 additions & 4 deletions gensim/models/ldamodel.py
Original file line number Diff line number Diff line change
Expand Up @@ -1091,9 +1091,9 @@ def save(self, fname, ignore=['state', 'dispatcher'], separately=None, *args, **
separately_explicit = ['expElogbeta', 'sstats']
# Also add 'alpha' and 'eta' to the `separately` list if they are set to 'auto' or
# were supplied manually as an array.
if (isinstance(self.alpha, six.string_types) and self.alpha == 'auto') or len(self.alpha.shape) != 1:
if (isinstance(self.alpha, six.string_types) and self.alpha == 'auto') or (isinstance(self.alpha, np.ndarray) and len(self.alpha.shape) != 1):
separately_explicit.append('alpha')
if (isinstance(self.eta, six.string_types) and self.eta == 'auto') or len(self.eta.shape) != 1:
if (isinstance(self.eta, six.string_types) and self.eta == 'auto') or (isinstance(self.eta, np.ndarray) and len(self.eta.shape) != 1):
separately_explicit.append('eta')
# Merge separately_explicit with separately.
if separately:
Expand All @@ -1117,6 +1117,9 @@ def load(cls, fname, *args, **kwargs):
"""
kwargs['mmap'] = kwargs.get('mmap', None)
result = super(LdaModel, cls).load(fname, *args, **kwargs)
if not hasattr(result, 'random_state'):
result.random_state = utils.get_random_state(None)
logging.warning("random_state not set so using default value")
state_fname = utils.smart_extension(fname, '.state')
try:
result.state = super(LdaModel, cls).load(state_fname, *args, **kwargs)
Expand All @@ -1128,7 +1131,5 @@ def load(cls, fname, *args, **kwargs):
result.id2word = utils.unpickle(id2word_fname)
except Exception as e:
logging.warning("failed to load id2word dictionary from %s: %s", id2word_fname, e)
else:
result.id2word = None
return result
# endclass LdaModel
Binary file added gensim/test/test_data/pre_0_13_2_model
Binary file not shown.
Binary file added gensim/test/test_data/pre_0_13_2_model.state
Binary file not shown.
35 changes: 30 additions & 5 deletions gensim/test/test_ldamodel.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@

def testfile(test_fname=''):
    """Return the path of a scratch file inside the system temp directory.

    `test_fname` is embedded in the file name so that different tests can ask
    for different scratch files; the file itself is not created here.
    """
    # temporary data will be stored to this file
    fname = 'gensim_models_' + test_fname + '.tst'
    return os.path.join(tempfile.gettempdir(), fname)


Expand Down Expand Up @@ -247,9 +247,9 @@ def testGetDocumentTopics(self):

# Test case to use the get_document_topics function for the corpus
all_topics = model.get_document_topics(self.corpus, per_word_topics=True)

self.assertEqual(model.state.numdocs, len(corpus))

for topic in all_topics:
self.assertTrue(isinstance(topic, tuple))
for k, v in topic[0]: # list of doc_topics
Expand All @@ -269,9 +269,9 @@ def testGetDocumentTopics(self):
word_phi_count_na = 0

all_topics = model.get_document_topics(self.corpus, minimum_probability=0.8, minimum_phi_value=1.0, per_word_topics=True)

self.assertEqual(model.state.numdocs, len(corpus))

for topic in all_topics:
self.assertTrue(isinstance(topic, tuple))
for k, v in topic[0]: # list of doc_topics
Expand Down Expand Up @@ -470,6 +470,31 @@ def testLargeMmapCompressed(self):
# test loading the large model arrays with mmap
self.assertRaises(IOError, self.class_.load, fname, mmap='r')

def testRandomStateBackwardCompatibility(self):
    """Load a pre-0.13.2 model, use it, re-save it and use the reloaded copy.

    The fixture `pre_0_13_2_model` was saved by an older Gensim release. The
    test checks that `print_topics` returns `(int, string)` pairs both for the
    freshly loaded old model and for the same model after a save/load round
    trip with the current code.
    """
    # load a model saved using a pre-0.13.2 version of Gensim
    pre_0_13_2_fname = datapath('pre_0_13_2_model')
    model_pre_0_13_2 = self.class_.load(pre_0_13_2_fname)

    # set `num_topics` less than `model_pre_0_13_2.num_topics` so that `model_pre_0_13_2.random_state` is used
    model_topics = model_pre_0_13_2.print_topics(num_topics=2, num_words=3)

    for topic in model_topics:
        self.assertTrue(isinstance(topic[0], int))
        self.assertTrue(isinstance(topic[1], six.string_types))

    # save back the loaded model using a post-0.13.2 version of Gensim;
    # write to a temp file so the run does not leave artifacts in the
    # checked-in test data directory
    post_0_13_2_fname = testfile('post_0_13_2_model')
    model_pre_0_13_2.save(post_0_13_2_fname)

    # load a model saved using a post-0.13.2 version of Gensim
    model_post_0_13_2 = self.class_.load(post_0_13_2_fname)
    model_topics_new = model_post_0_13_2.print_topics(num_topics=2, num_words=3)

    for topic in model_topics_new:
        self.assertTrue(isinstance(topic[0], int))
        self.assertTrue(isinstance(topic[1], six.string_types))


#endclass TestLdaModel


Expand Down

0 comments on commit 7414b60

Please sign in to comment.