Skip to content

Commit

Permalink
fix normalize issue for pickling
Browse files (browse the repository at this point in the history)
  • Loading branch information
markroxor committed Dec 22, 2017
1 parent b2def84 commit ac4b154
Show file tree
Hide file tree
Showing 2 changed files with 129 additions and 172 deletions.
28 changes: 12 additions & 16 deletions gensim/models/tfidfmodel.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ def precompute_idfs(wglobal, dfs, total_docs):
return {termid: wglobal(df, total_docs) for termid, df in iteritems(dfs)}


def wlocal_g(tf, n_tf): # TODO rename it (to avoid confusion)
def updated_wlocal(tf, n_tf):
if n_tf == "n":
return tf
elif n_tf == "l":
Expand All @@ -67,7 +67,7 @@ def wlocal_g(tf, n_tf): # TODO rename it (to avoid confusion)
return (1 + np.log(tf) / np.log(2)) / (1 + np.log(tf.mean(axis=0) / np.log(2)))


def wglobal_g(docfreq, totaldocs, n_df): # TODO rename it (to avoid confusion)
def updated_wglobal(docfreq, totaldocs, n_df): # TODO rename it (to avoid confusion)
if n_df == "n":
return utils.identity(docfreq)
elif n_df == "t":
Expand All @@ -76,14 +76,11 @@ def wglobal_g(docfreq, totaldocs, n_df): # TODO rename it (to avoid confusion)
return np.log((1.0 * totaldocs - docfreq) / docfreq) / np.log(2)


def normalize_g(x, n_n): # TODO rename it (to avoid confusion)
def updated_normalize(x, n_n): # TODO rename it (to avoid confusion)
if n_n == "n":
return x
elif n_n == "c":
return matutils.unitvec(x)
# TODO write byte-size normalisation
# elif n_n == "b":
# pass


class TfidfModel(interfaces.TransformationABC):
Expand Down Expand Up @@ -152,7 +149,7 @@ def __init__(self, corpus=None, id2word=None, dictionary=None, wlocal=utils.iden
Document frequency weighting:
none - `n`, idf - `t`, prob idf - `p`.
Document normalization:
none - `n`, cosine - `c`, byte size - `b`.
none - `n`, cosine - `c`.
for more information visit https://en.wikipedia.org/wiki/SMART_Information_Retrieval_System
Expand All @@ -167,18 +164,13 @@ def __init__(self, corpus=None, id2word=None, dictionary=None, wlocal=utils.iden
self.num_docs, self.num_nnz, self.idfs = None, None, None
self.smartirs = smartirs

if self.normalize is True:
self.normalize = matutils.unitvec
elif self.normalize is False:
self.normalize = utils.identity

# If smartirs is not None, override wlocal, wglobal and normalize
if smartirs is not None:
n_tf, n_df, n_n = resolve_weights(smartirs)

self.wlocal = partial(wlocal_g, n_tf=n_tf)
self.wglobal = partial(wglobal_g, n_df=n_df)
self.normalize = partial(normalize_g, n_n=n_n)
self.wlocal = partial(updated_wlocal, n_tf=n_tf)
self.wglobal = partial(updated_wglobal, n_df=n_df)
self.normalize = partial(updated_normalize, n_n=n_n)

if dictionary is not None:
# user supplied a Dictionary object, which already contains all the
Expand Down Expand Up @@ -255,9 +247,13 @@ def __getitem__(self, bow, eps=1e-12):
for termid, tf in zip(termid_array, tf_array) if self.idfs.get(termid, 0.0) != 0.0
]

if self.normalize is True:
self.normalize = matutils.unitvec
elif self.normalize is False:
self.normalize = utils.identity

# and finally, normalize the vector either to unit length, or use a
# user-defined normalization function

vector = self.normalize(vector)

# make sure there are no explicit zeroes in the vector (must be sparse)
Expand Down
Loading

0 comments on commit ac4b154

Please sign in to comment.