Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Pivot Normalization for gensim.models.TfidfModel. Fix #220 #1780

Merged
merged 62 commits into from
Mar 13, 2018
Merged
Show file tree
Hide file tree
Changes from 48 commits
Commits
Show all changes
62 commits
Select commit Hold shift + click to select a range
efb7e3c
pivot normalization
markroxor Dec 12, 2017
b7d07d4
Merge branch 'develop' of https://github.com/RaRe-Technologies/gensim…
markroxor Dec 12, 2017
e8a3f16
verify weights
markroxor Dec 15, 2017
648bf21
verify weights
markroxor Dec 15, 2017
a6f1afb
smartirs ready
markroxor Dec 15, 2017
d091138
change old tests
markroxor Dec 15, 2017
951c549
remove lambdas
markroxor Dec 15, 2017
40c0558
address suggestions
markroxor Dec 16, 2017
b35344c
minor fix
markroxor Dec 19, 2017
634d595
pep8 fix
markroxor Dec 19, 2017
0917e75
pep8 fix
markroxor Dec 19, 2017
bef79cc
numpy style doc strings
markroxor Dec 19, 2017
d3d431c
fix pickle problem
menshikh-iv Dec 21, 2017
0e6f21e
flake8 fix
markroxor Dec 21, 2017
7ee7560
fix bug in docstring
menshikh-iv Dec 21, 2017
b2def84
added few tests
markroxor Dec 22, 2017
5b2d37a
fix normalize issue for pickling
markroxor Dec 22, 2017
ac4b154
fix normalize issue for pickling
markroxor Dec 22, 2017
0bacc08
test without sklearn api
markroxor Dec 22, 2017
51e0eb9
Merge branch 'smartirs' of github.com:markroxor/gensim into smartirs
markroxor Dec 22, 2017
3039732
hanging idents and new tests
markroxor Dec 25, 2017
99e6a6f
Merge branch 'develop' of https://github.com/RaRe-Technologies/gensim…
markroxor Dec 26, 2017
7d63d9c
Merge branch 'smartirs' of github.com:markroxor/gensim into smartirs
markroxor Dec 26, 2017
e5140f8
add docstring
markroxor Dec 26, 2017
4afbadd
add docstring
markroxor Dec 26, 2017
d2fe235
Merge branch 'smartirs' of github.com:markroxor/gensim into smartirs
markroxor Dec 26, 2017
5565c78
Merge branch 'develop' of https://github.com/RaRe-Technologies/gensim…
markroxor Dec 26, 2017
099dbdf
merge conflicts fix
markroxor Dec 26, 2017
ef67f63
pivotized normalization
markroxor Dec 26, 2017
52ee3c4
better way cmparing floats
markroxor Dec 27, 2017
3087030
pass tests
markroxor Dec 27, 2017
62bba1b
pass tests
markroxor Dec 27, 2017
0a9f816
merging
markroxor Feb 12, 2018
dc63ab9
Merge branch 'pivot_norm' of github.com:markroxor/gensim into pivot_norm
markroxor Feb 12, 2018
035c8c5
merge develop
markroxor Feb 12, 2018
dc4ca52
added benchmarks
markroxor Feb 12, 2018
1ee449d
address comments
markroxor Feb 15, 2018
4ea6caa
benchmarking
markroxor Feb 28, 2018
b3cead6
testing pipeline
markroxor Mar 1, 2018
044332b
pivoted normalisation
markroxor Mar 1, 2018
1c2196c
taking overall norm
markroxor Mar 3, 2018
309b4e8
Update tfidfmodel.py
markroxor Mar 3, 2018
3866a9c
Update sklearn_api.ipynb
markroxor Mar 3, 2018
12b42e6
tests for pivoted normalization
markroxor Mar 5, 2018
0ff6ad7
results
markroxor Mar 5, 2018
65c651b
adding visualizations
markroxor Mar 5, 2018
4a947ba
minor nb changes
markroxor Mar 6, 2018
619bb33
minor nb changes
markroxor Mar 6, 2018
f105190
removed self.pivoted_normalisation
markroxor Mar 11, 2018
6410f21
Update test_tfidfmodel.py
markroxor Mar 11, 2018
2eb6fc2
Merge branch 'develop' of https://github.com/RaRe-Technologies/gensim…
markroxor Mar 12, 2018
a65dccf
minor suggestions
markroxor Mar 12, 2018
8717350
added description
markroxor Mar 12, 2018
95cb630
added description
markroxor Mar 12, 2018
5f46d2f
Merge branch 'pivot_norm' of github.com:markroxor/gensim into pivot_norm
markroxor Mar 13, 2018
2c7115d
last commit
markroxor Mar 13, 2018
1fe46f8
Merge remote-tracking branch 'upstream/develop' into pivot_norm
menshikh-iv Mar 13, 2018
63c8385
cleanup
menshikh-iv Mar 13, 2018
5e87229
cosmetic fixes
menshikh-iv Mar 13, 2018
9f2b02c
changed pivot
markroxor Mar 13, 2018
fc701a1
changed pivot
markroxor Mar 13, 2018
1868da5
fixed comments
markroxor Mar 13, 2018
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Binary file added docs/notebooks/line.gif
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
473 changes: 473 additions & 0 deletions docs/notebooks/pivoted_document_length_normalisation.ipynb

Large diffs are not rendered by default.

27 changes: 21 additions & 6 deletions gensim/matutils.py
Original file line number Diff line number Diff line change
Expand Up @@ -666,7 +666,7 @@ def ret_log_normalize_vec(vec, axis=1):
blas_scal = blas('scal', np.array([], dtype=float))


def unitvec(vec, norm='l2'):
def unitvec(vec, norm='l2', return_norm=False):
mpenkov marked this conversation as resolved.
Show resolved Hide resolved
"""Scale a vector to unit length.

Parameters
Expand Down Expand Up @@ -695,9 +695,15 @@ def unitvec(vec, norm='l2'):
if norm == 'l2':
veclen = np.sqrt(np.sum(vec.data ** 2))
if veclen > 0.0:
return vec / veclen
if return_norm:
return vec / veclen, veclen
else:
return vec / veclen
else:
return vec
if return_norm:
return vec, 1
else:
return vec

if isinstance(vec, np.ndarray):
vec = np.asarray(vec, dtype=float)
Expand All @@ -706,9 +712,15 @@ def unitvec(vec, norm='l2'):
if norm == 'l2':
veclen = blas_nrm2(vec)
if veclen > 0.0:
return blas_scal(1.0 / veclen, vec)
if return_norm:
return blas_scal(1.0 / veclen, vec), veclen
else:
return blas_scal(1.0 / veclen, vec)
else:
return vec
if return_norm:
return vec, 1
else:
return vec

try:
first = next(iter(vec)) # is there at least one element?
Expand All @@ -721,7 +733,10 @@ def unitvec(vec, norm='l2'):
if norm == 'l2':
length = 1.0 * math.sqrt(sum(val ** 2 for _, val in vec))
assert length > 0.0, "sparse documents must not contain any explicit zero entries"
return ret_normalized_vec(vec, length)
if return_norm:
return ret_normalized_vec(vec, length), length
else:
return ret_normalized_vec(vec, length)
else:
raise ValueError("unknown input type")

Expand Down
84 changes: 72 additions & 12 deletions gensim/models/tfidfmodel.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from gensim import interfaces, matutils, utils
from six import iteritems

from scipy import sparse as sp
import numpy as np

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -70,7 +71,7 @@ def resolve_weights(smartirs):
if w_df not in 'ntp':
raise ValueError("Expected inverse document frequency weight to be one of 'ntp', except got {}".format(w_df))

if w_n not in 'ncb':
if w_n not in 'nc':
raise ValueError("Expected normalization weight to be one of 'ncb', except got {}".format(w_n))

return w_tf, w_df, w_n
Expand Down Expand Up @@ -177,7 +178,7 @@ def updated_wglobal(docfreq, totaldocs, n_df):
return np.log((1.0 * totaldocs - docfreq) / docfreq) / np.log(2)


def updated_normalize(x, n_n):
def updated_normalize(x, n_n, return_norm=False):
mpenkov marked this conversation as resolved.
Show resolved Hide resolved
"""Normalizes the final tf-idf value according to the value of `n_n`.

Parameters
Expand All @@ -194,9 +195,12 @@ def updated_normalize(x, n_n):

"""
if n_n == "n":
return x
if return_norm:
return x, 1
else:
return x
elif n_n == "c":
return matutils.unitvec(x)
return matutils.unitvec(x, return_norm=return_norm)


class TfidfModel(interfaces.TransformationABC):
Expand All @@ -219,7 +223,8 @@ class TfidfModel(interfaces.TransformationABC):
"""

def __init__(self, corpus=None, id2word=None, dictionary=None, wlocal=utils.identity,
wglobal=df2idf, normalize=True, smartirs=None):
wglobal=df2idf, normalize=True, smartirs=None,
pivot_norm=False, slope=0.65, pivot=None):
"""Compute tf-idf by multiplying a local component (term frequency) with a global component
(inverse document frequency), and normalizing the resulting documents to unit length.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Missed docstring for new parameters

Formula for non-normalized weight of term :math:`i` in document :math:`j` in a corpus of :math:`D` documents
Expand Down Expand Up @@ -273,21 +278,34 @@ def __init__(self, corpus=None, id2word=None, dictionary=None, wlocal=utils.iden
* `c` - cosine.

For more information visit [1]_.

pivot_norm : bool, optional
If pivot_norm is True, then pivoted document length normalization will be applied.
slope : float, optional
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This can confuse users, need to mention that works only if pivot set.

It is the parameter required by pivoted document length normalization which determines the slope to which
the `old normalization` can be tilted.
pivot : int/float, optional
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

float

Pivot is the point before which we consider a document to be short and after which the document is
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

too "broad" description, next question will be "what is retrieval and relevence curves" and "how to plot it"

considered long. It can be found by plotting the retrieval and relevence curves of a set of documents using
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

also, you can add latex formula here using http://www.sphinx-doc.org/en/master/ext/math.html

a general normalization function. The point where both these curves coincide is the pivot point.
"""

self.id2word = id2word
self.wlocal, self.wglobal, self.normalize = wlocal, wglobal, normalize
self.num_docs, self.num_nnz, self.idfs = None, None, None
self.smartirs = smartirs
self.pivot_norm = pivot_norm
self.slope = slope
self.pivot = pivot
self.eps = 1e-12

# If smartirs is not None, override wlocal, wglobal and normalize
if smartirs is not None:
n_tf, n_df, n_n = resolve_weights(smartirs)

self.wlocal = partial(updated_wlocal, n_tf=n_tf)
self.wglobal = partial(updated_wglobal, n_df=n_df)
self.normalize = partial(updated_normalize, n_n=n_n)
# also return norm factor if pivot_norm is True
self.normalize = partial(updated_normalize, n_n=n_n, return_norm=self.pivot_norm)

if dictionary is not None:
# user supplied a Dictionary object, which already contains all the
Expand All @@ -309,6 +327,19 @@ def __init__(self, corpus=None, id2word=None, dictionary=None, wlocal=utils.iden
# be initialized in some other way
pass

@classmethod
def load(cls, *args, **kwargs):
"""
Load a previously saved TfidfModel class. Handles backwards compatibility from
older TfidfModel versions which did not use pivoted document normalization.
"""
model = super(TfidfModel, cls).load(*args, **kwargs)
if not hasattr(model, 'pivot_norm'):
logger.info('older version of %s loaded without pivot_norm arg', cls.__name__)
logger.info('Setting pivot_norm to False.')
model.pivot_norm = False
return model

def __str__(self):
return "TfidfModel(num_docs=%s, num_nnz=%s)" % (self.num_docs, self.num_nnz)

Expand Down Expand Up @@ -360,6 +391,7 @@ def __getitem__(self, bow, eps=1e-12):
TfIdf corpus, if `bow` is corpus.

"""
self.eps = eps
# if the input vector is in fact a corpus, return a transformed corpus as a result
is_corpus, bow = utils.is_corpus(bow)
if is_corpus:
Expand All @@ -377,7 +409,7 @@ def __getitem__(self, bow, eps=1e-12):

vector = [
(termid, tf * self.idfs.get(termid))
for termid, tf in zip(termid_array, tf_array) if abs(self.idfs.get(termid, 0.0)) > eps
for termid, tf in zip(termid_array, tf_array) if abs(self.idfs.get(termid, 0.0)) > self.eps
]

if self.normalize is True:
Expand All @@ -387,8 +419,36 @@ def __getitem__(self, bow, eps=1e-12):

# and finally, normalize the vector either to unit length, or use a
# user-defined normalization function
vector = self.normalize(vector)
if self.pivot_norm is False:
norm_vector = self.normalize(vector)
norm_vector = [(termid, weight) for termid, weight in norm_vector if abs(weight) > self.eps]
return norm_vector
else:
logger.info("You need to explicitly call pivoted_normalization.")
return vector

def pivoted_normalization(self, tfidf_matrix):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ This is the most suspicious place, need to investigate, how this can be done in different way ⚠️

X = matutils.corpus2csc(tfidf_matrix).T
n_samples, n_features = X.shape
X_norm = []

for vec in tfidf_matrix:
_, norm = self.normalize(vec, return_norm=True)
X_norm.append(norm)

X_norm = np.array(X_norm)

if self.pivot is None:
self.pivot = X_norm.mean()

pivoted_norm = (1 - self.slope) * self.pivot + self.slope * X_norm
_diag_pivoted_norm = sp.spdiags(1. / pivoted_norm, diags=0, m=n_samples,
n=n_samples, format='csr')
X = _diag_pivoted_norm.dot(X)

norm_vector = []

# make sure there are no explicit zeroes in the vector (must be sparse)
vector = [(termid, weight) for termid, weight in vector if abs(weight) > eps]
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we used 140 characters limit -> no need to split this line here

return vector
X = matutils.Scipy2Corpus(X)
for doc in X:
norm_vector.append([(termid, weight) for termid, weight in doc if abs(weight) > self.eps])
return norm_vector
14 changes: 12 additions & 2 deletions gensim/sklearn_api/tfidf.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,8 @@ class TfIdfTransformer(TransformerMixin, BaseEstimator):
"""

def __init__(self, id2word=None, dictionary=None, wlocal=gensim.utils.identity,
wglobal=gensim.models.tfidfmodel.df2idf, normalize=True, smartirs="ntc"):
wglobal=gensim.models.tfidfmodel.df2idf, normalize=True, smartirs="ntc",
pivot_norm=False, slope=0.65, pivot=None):
"""
Sklearn wrapper for Tf-Idf model.
"""
Expand All @@ -33,6 +34,9 @@ def __init__(self, id2word=None, dictionary=None, wlocal=gensim.utils.identity,
self.wglobal = wglobal
self.normalize = normalize
self.smartirs = smartirs
self.pivot_norm = pivot_norm
self.slope = slope
self.pivot = pivot

def fit(self, X, y=None):
"""
Expand All @@ -41,6 +45,7 @@ def fit(self, X, y=None):
self.gensim_model = TfidfModel(
corpus=X, id2word=self.id2word, dictionary=self.dictionary, wlocal=self.wlocal,
wglobal=self.wglobal, normalize=self.normalize, smartirs=self.smartirs,
pivot_norm=self.pivot_norm, slope=self.slope, pivot=self.pivot
)
return self

Expand All @@ -56,4 +61,9 @@ def transform(self, docs):
# input as python lists
if isinstance(docs[0], tuple):
docs = [docs]
return [self.gensim_model[doc] for doc in docs]

tfidf_matrix = [self.gensim_model[doc] for doc in docs]
if self.pivot_norm is True:
return self.gensim_model.pivoted_normalization(tfidf_matrix)
else:
return tfidf_matrix
31 changes: 30 additions & 1 deletion gensim/test/test_tfidfmodel.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,7 @@ def testPersistenceCompressed(self):
self.assertTrue(np.allclose(model3[tstvec[1]], model4[tstvec[1]]))
self.assertTrue(np.allclose(model3[[]], model4[[]])) # try projecting an empty vector

def TestConsistency(self):
def testConsistency(self):
docs = [corpus[1], corpus[2]]

# Test if `ntc` yields the default docs.
Expand Down Expand Up @@ -283,6 +283,35 @@ def TestConsistency(self):
self.assertTrue(np.allclose(transformed_docs[0], expected_docs[0]))
self.assertTrue(np.allclose(transformed_docs[1], expected_docs[1]))

def testPivotedNormalization(self):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

please add test (load old model with new code, same as for SMART feature)

docs = [corpus[1], corpus[2]]

# Test if slope=1 yields the default docs for pivoted normalization.
model = tfidfmodel.TfidfModel(self.corpus)
transformed_docs = [model[docs[0]], model[docs[1]]]

model = tfidfmodel.TfidfModel(self.corpus, slope=1, pivot_norm=True)
expected_docs = model.pivoted_normalization([model[docs[0]], model[docs[1]]])

self.assertTrue(np.allclose(sorted(transformed_docs[0]), sorted(expected_docs[0])))
self.assertTrue(np.allclose(sorted(transformed_docs[1]), sorted(expected_docs[1])))

# Test if pivoted model is consistent
model = tfidfmodel.TfidfModel(self.corpus, slope=0.5, pivot_norm=True)
transformed_docs = model.pivoted_normalization([model[docs[0]], model[docs[1]]])
expected_docs = [[(8, 0.4682642467547897),
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

use hanging indents

(7, 0.34203084025552255),
(6, 0.4682642467547897),
(5, 0.34203084025552255),
(4, 0.4682642467547897),
(3, 0.4682642467547897)],
[(10, 0.3834996737115108),
(9, 0.3834996737115108),
(5, 0.7669993474230216)]]

self.assertTrue(np.allclose(sorted(transformed_docs[0]), sorted(expected_docs[0])))
self.assertTrue(np.allclose(sorted(transformed_docs[1]), sorted(expected_docs[1])))


if __name__ == '__main__':
logging.basicConfig(format='%(asctime)s : %(levelname)s : %(message)s', level=logging.DEBUG)
Expand Down