Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WIP] Fix backward incompatibility due to random_state #1327

Merged
merged 11 commits into from
May 25, 2017
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 5 additions & 4 deletions gensim/models/ldamodel.py
Original file line number Diff line number Diff line change
Expand Up @@ -1023,9 +1023,9 @@ def save(self, fname, ignore=['state', 'dispatcher'], separately=None, *args, **
separately_explicit = ['expElogbeta', 'sstats']
# Also add 'alpha' and 'eta' to the `separately` list if they are set to 'auto'
# or were supplied manually as an array.
if (isinstance(self.alpha, six.string_types) and self.alpha == 'auto') or len(self.alpha.shape) != 1:
if (isinstance(self.alpha, six.string_types) and self.alpha == 'auto') or (isinstance(self.alpha, np.ndarray) and len(self.alpha.shape) != 1):
separately_explicit.append('alpha')
if (isinstance(self.eta, six.string_types) and self.eta == 'auto') or len(self.eta.shape) != 1:
if (isinstance(self.eta, six.string_types) and self.eta == 'auto') or (isinstance(self.eta, np.ndarray) and len(self.eta.shape) != 1):
separately_explicit.append('eta')
# Merge separately_explicit with separately.
if separately:
Expand All @@ -1049,6 +1049,9 @@ def load(cls, fname, *args, **kwargs):
"""
kwargs['mmap'] = kwargs.get('mmap', None)
result = super(LdaModel, cls).load(fname, *args, **kwargs)
if not hasattr(result, 'random_state'):
result.random_state = utils.get_random_state(None)
logging.warning("random_state not set so using default value")
state_fname = utils.smart_extension(fname, '.state')
try:
result.state = super(LdaModel, cls).load(state_fname, *args, **kwargs)
Expand All @@ -1060,7 +1063,5 @@ def load(cls, fname, *args, **kwargs):
result.id2word = utils.unpickle(id2word_fname)
except Exception as e:
logging.warning("failed to load id2word dictionary from %s: %s", id2word_fname, e)
else:
result.id2word = None
return result
# endclass LdaModel
Binary file added gensim/test/test_data/pre_0_13_2_model
Binary file not shown.
Binary file added gensim/test/test_data/pre_0_13_2_model.state
Binary file not shown.
35 changes: 30 additions & 5 deletions gensim/test/test_ldamodel.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@

def testfile(test_fname=''):
    """Return the path of a scratch file inside the system temp directory.

    The file name is ``'gensim_models_' + test_fname + '.tst'``; nothing is
    created on disk, only the path is built.

    :param test_fname: optional tag inserted into the file name so that
        different tests can use different scratch files.
    :return: path to the scratch file under ``tempfile.gettempdir()``.
    """
    # temporary data will be stored to this file
    fname = 'gensim_models_' + test_fname + '.tst'
    return os.path.join(tempfile.gettempdir(), fname)


Expand Down Expand Up @@ -247,9 +247,9 @@ def testGetDocumentTopics(self):

# Test case for calling get_document_topics on the whole corpus
all_topics = model.get_document_topics(self.corpus, per_word_topics=True)

self.assertEqual(model.state.numdocs, len(corpus))

for topic in all_topics:
self.assertTrue(isinstance(topic, tuple))
for k, v in topic[0]: # list of doc_topics
Expand All @@ -269,9 +269,9 @@ def testGetDocumentTopics(self):
word_phi_count_na = 0

all_topics = model.get_document_topics(self.corpus, minimum_probability=0.8, minimum_phi_value=1.0, per_word_topics=True)

self.assertEqual(model.state.numdocs, len(corpus))

for topic in all_topics:
self.assertTrue(isinstance(topic, tuple))
for k, v in topic[0]: # list of doc_topics
Expand Down Expand Up @@ -470,6 +470,31 @@ def testLargeMmapCompressed(self):
# test loading the large model arrays with mmap
self.assertRaises(IOError, self.class_.load, fname, mmap='r')

def testRandomStateBackwardCompatibility(self):
    """Check that a model saved by a pre-0.13.2 Gensim can be loaded, used and re-saved.

    The stored fixture ``pre_0_13_2_model`` is loaded and queried with
    ``print_topics``, then saved with the current code, loaded again and
    queried again. Each returned topic must be an ``(int, string)`` pair.
    """
    # load a model saved using a pre-0.13.2 version of Gensim
    pre_0_13_2_fname = datapath('pre_0_13_2_model')
    model_pre_0_13_2 = self.class_.load(pre_0_13_2_fname)

    # set `num_topics` less than `model_pre_0_13_2.num_topics` so that `model_pre_0_13_2.random_state` is used
    model_topics = model_pre_0_13_2.print_topics(num_topics=2, num_words=3)

    for i in model_topics:
        self.assertTrue(isinstance(i[0], int))
        self.assertTrue(isinstance(i[1], six.string_types))

    # save back the loaded model using a post-0.13.2 version of Gensim;
    # write to the temp directory (via `testfile`) rather than the test data
    # directory, so running the tests does not leave generated files in the repo
    post_0_13_2_fname = testfile('post_0_13_2_model')
    model_pre_0_13_2.save(post_0_13_2_fname)

    # load a model saved using a post-0.13.2 version of Gensim
    model_post_0_13_2 = self.class_.load(post_0_13_2_fname)
    model_topics_new = model_post_0_13_2.print_topics(num_topics=2, num_words=3)

    for i in model_topics_new:
        self.assertTrue(isinstance(i[0], int))
        self.assertTrue(isinstance(i[1], six.string_types))


#endclass TestLdaModel


Expand Down