-
Notifications
You must be signed in to change notification settings - Fork 2
/
gensim_wrapper.py
91 lines (69 loc) · 3.84 KB
/
gensim_wrapper.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
import numpy as np
from gensim import corpora, models, matutils
from sklearn.base import BaseEstimator
class LdaTransformer(BaseEstimator):
    """
    scikit-learn style transformer wrapping gensim's LdaModel.

    See http://radimrehurek.com/gensim/models/ldamodel.html for parameter usage.
    X should be a list of tokens for each document, e.g. [['This', 'is', 'document', '1'], ['Second', 'document']]

    ``fit`` builds a gensim Dictionary from X and converts every document to a
    bag-of-words vector. It optionally re-weights those vectors with a
    TfidfModel (``use_tfidf``) and then trains an LdaModel on the result.
    ``transform`` sends new documents through the same pipeline and returns a
    dense array with one row per document and ``n_latent_topics`` columns.

    ``n_latent_topics`` is passed to gensim as ``num_topics``. Every other
    constructor argument except ``use_tfidf`` is forwarded unchanged to the
    LdaModel argument of the same name.
    """

    def __init__(self, n_latent_topics=100, use_tfidf=False, distributed=False,
                 chunksize=2000, passes=1, update_every=1, alpha='symmetric',
                 eta=None, decay=0.5, eval_every=10, iterations=50,
                 gamma_threshold=0.001):
        # Store arguments verbatim. scikit-learn's clone()/set_params() rely
        # on attribute names matching the constructor parameter names.
        self.n_latent_topics = n_latent_topics
        self.use_tfidf = use_tfidf
        self.distributed = distributed
        self.chunksize = chunksize
        self.passes = passes
        self.update_every = update_every
        self.alpha = alpha
        self.eta = eta
        self.decay = decay
        self.eval_every = eval_every
        self.iterations = iterations
        self.gamma_threshold = gamma_threshold

    def transform(self, X):
        """
        Project tokenized documents into the fitted LDA topic space.

        Parameters
        ----------
        X : iterable of list of str
            Tokenized documents. Tokens that are not in the dictionary built
            by ``fit`` are ignored by ``Dictionary.doc2bow``.

        Returns
        -------
        numpy.ndarray
            Dense array of shape (n_documents, n_latent_topics).
        """
        corpus = [self.dictionary.doc2bow(text) for text in X]
        if self.use_tfidf:
            corpus = self.tfidf[corpus]
        corpus_lda = self.model[corpus]
        # corpus2dense produces a (num_topics, num_docs) matrix. Transpose it
        # to the scikit-learn convention of one row per document.
        corpus_lda_dense = matutils.corpus2dense(corpus_lda, self.n_latent_topics).T
        return corpus_lda_dense

    def fit(self, X, y=None):
        """
        Build the vocabulary and train the LDA model on X.

        Parameters
        ----------
        X : iterable of list of str
            Tokenized training documents.
        y : ignored
            Present only for scikit-learn API compatibility.

        Returns
        -------
        LdaTransformer
            ``self``, fitted.
        """
        # X is traversed twice: once to build the vocabulary, then again to
        # build the bag-of-words corpus. Materialize it first; otherwise a
        # one-shot iterable (e.g. a generator) is exhausted by the first
        # pass and the model is silently trained on an empty corpus.
        X = list(X)
        self.dictionary = corpora.Dictionary(X)
        corpus = [self.dictionary.doc2bow(text) for text in X]
        if self.use_tfidf:
            self.tfidf = models.TfidfModel(corpus)
            corpus = self.tfidf[corpus]
        self.model = models.LdaModel(
            corpus,
            id2word=self.dictionary,
            num_topics=self.n_latent_topics,
            distributed=self.distributed,
            chunksize=self.chunksize,
            passes=self.passes,
            update_every=self.update_every,
            alpha=self.alpha,
            eta=self.eta,
            decay=self.decay,
            eval_every=self.eval_every,
            iterations=self.iterations,
            gamma_threshold=self.gamma_threshold,
        )
        return self

    def get_params(self, deep=False):
        """
        Return every constructor parameter as a dict.

        Each ``__init__`` argument must appear here. scikit-learn's ``clone``
        (used by GridSearchCV, cross_val_score, ...) rebuilds the estimator
        from this dict, so a missing key silently reverts to its default.
        ``set_params`` also rejects any key that is not listed here.
        ``deep`` is accepted for API compatibility and is ignored.
        """
        return {
            'n_latent_topics': self.n_latent_topics,
            'use_tfidf': self.use_tfidf,
            'distributed': self.distributed,
            'chunksize': self.chunksize,
            'passes': self.passes,
            'update_every': self.update_every,
            'alpha': self.alpha,
            'eta': self.eta,
            'decay': self.decay,
            'eval_every': self.eval_every,
            'iterations': self.iterations,
            'gamma_threshold': self.gamma_threshold,
        }
class LsiTransformer(BaseEstimator):
    """
    scikit-learn style transformer wrapping gensim's LsiModel.

    See http://radimrehurek.com/gensim/models/lsimodel.html for parameter usage.
    X should be a list of tokens for each document, e.g. [['This', 'is', 'document', '1'], ['Second', 'document']]

    ``fit`` builds a gensim Dictionary from X and converts every document to a
    bag-of-words vector. It optionally re-weights those vectors with a
    TfidfModel (``use_tfidf``, on by default here) and then trains an
    LsiModel on the result. ``transform`` sends new documents through the
    same pipeline and returns a dense array with one row per document and
    ``n_latent_topics`` columns.

    ``n_latent_topics`` is passed to gensim as ``num_topics``. Every other
    constructor argument except ``use_tfidf`` is forwarded unchanged to the
    LsiModel argument of the same name.
    """

    def __init__(self, n_latent_topics=100, use_tfidf=True, chunksize=20000,
                 decay=1.0, distributed=False, onepass=True, power_iters=2,
                 extra_samples=100):
        # Store arguments verbatim. scikit-learn's clone()/set_params() rely
        # on attribute names matching the constructor parameter names.
        self.n_latent_topics = n_latent_topics
        self.use_tfidf = use_tfidf
        self.chunksize = chunksize
        self.decay = decay
        self.distributed = distributed
        self.onepass = onepass
        self.power_iters = power_iters
        self.extra_samples = extra_samples

    def transform(self, X):
        """
        Project tokenized documents into the fitted LSI topic space.

        Parameters
        ----------
        X : iterable of list of str
            Tokenized documents. Tokens that are not in the dictionary built
            by ``fit`` are ignored by ``Dictionary.doc2bow``.

        Returns
        -------
        numpy.ndarray
            Dense array of shape (n_documents, n_latent_topics).
        """
        corpus = [self.dictionary.doc2bow(text) for text in X]
        if self.use_tfidf:
            corpus = self.tfidf[corpus]
        corpus_lsi = self.model[corpus]
        # corpus2dense produces a (num_topics, num_docs) matrix. Transpose it
        # to the scikit-learn convention of one row per document.
        corpus_lsi_dense = matutils.corpus2dense(corpus_lsi, self.n_latent_topics).T
        return corpus_lsi_dense

    def fit(self, X, y=None):
        """
        Build the vocabulary and train the LSI model on X.

        Parameters
        ----------
        X : iterable of list of str
            Tokenized training documents.
        y : ignored
            Present only for scikit-learn API compatibility.

        Returns
        -------
        LsiTransformer
            ``self``, fitted.
        """
        # X is traversed twice: once to build the vocabulary, then again to
        # build the bag-of-words corpus. Materialize it first; otherwise a
        # one-shot iterable (e.g. a generator) is exhausted by the first
        # pass and the model is silently trained on an empty corpus.
        X = list(X)
        self.dictionary = corpora.Dictionary(X)
        corpus = [self.dictionary.doc2bow(text) for text in X]
        if self.use_tfidf:
            self.tfidf = models.TfidfModel(corpus)
            corpus = self.tfidf[corpus]
        self.model = models.LsiModel(
            corpus,
            id2word=self.dictionary,
            num_topics=self.n_latent_topics,
            chunksize=self.chunksize,
            decay=self.decay,
            distributed=self.distributed,
            onepass=self.onepass,
            power_iters=self.power_iters,
            extra_samples=self.extra_samples,
        )
        return self

    def get_params(self, deep=False):
        """
        Return every constructor parameter as a dict.

        Each ``__init__`` argument must appear here. scikit-learn's ``clone``
        rebuilds the estimator from this dict, and ``set_params`` rejects
        any key that is not listed here. ``deep`` is accepted for API
        compatibility and is ignored.
        """
        return {
            'n_latent_topics': self.n_latent_topics,
            'use_tfidf': self.use_tfidf,
            'chunksize': self.chunksize,
            'decay': self.decay,
            'distributed': self.distributed,
            'onepass': self.onepass,
            'power_iters': self.power_iters,
            'extra_samples': self.extra_samples,
        }