Skip to content

Commit 2430cca

Browse files
committed
Attempting to fix slow NaiveBayes
Three changes: 1) basic_extractor can accept a list of strings as well as a list of ('word','label') tuples. 2) BaseClassifier now has an instance variable _word_set which is a set of tokens seen by the classifier. 1+2) BaseClassifier.extract_features passes _word_set to extractor rather than the training set. 3) NLTKClassifier.update adds new words to the _word_set.
1 parent ca0555a commit 2430cca

File tree

1 file changed

+11
-3
lines changed

1 file changed

+11
-3
lines changed

textblob/classifiers.py

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -76,9 +76,15 @@ def basic_extractor(document, train_set):
7676
7777
:param document: The text to extract features from. Can be a string or an iterable.
7878
:param list train_set: Training data set, a list of tuples of the form
79-
``(words, label)``.
79+
``(words, label)`` OR an iterable of strings.
8080
"""
81-
word_features = _get_words_from_dataset(train_set)
81+
el_zero = iter(train_set).next() #Infer input from first element.
82+
if isinstance(el_zero, tuple):
83+
word_features = _get_words_from_dataset(train_set)
84+
elif isinstance(el_zero, str):
85+
word_features = train_set
86+
else:
87+
raise ValueError('train_set is proabably malformed.')
8288
tokens = _get_document_tokens(document)
8389
features = dict(((u'contains({0})'.format(word), (word in tokens))
8490
for word in word_features))
@@ -123,6 +129,7 @@ def __init__(self, train_set, feature_extractor=basic_extractor, format=None, **
123129
self.train_set = self._read_data(train_set, format)
124130
else: # train_set is a list of tuples
125131
self.train_set = train_set
132+
self._word_set = _get_words_from_dataset(train_set) #Keep a hidden set of unique words.
126133
self.train_features = None
127134

128135
def _read_data(self, dataset, format=None):
@@ -166,7 +173,7 @@ def extract_features(self, text):
166173
'''
167174
# Feature extractor may take one or two arguments
168175
try:
169-
return self.feature_extractor(text, self.train_set)
176+
return self.feature_extractor(text, self._word_set)
170177
except (TypeError, AttributeError):
171178
return self.feature_extractor(text)
172179

@@ -260,6 +267,7 @@ def update(self, new_data, *args, **kwargs):
260267
``(text, label)``.
261268
"""
262269
self.train_set += new_data
270+
self._word_set.update(_get_words_from_dataset(new_data))
263271
self.train_features = [(self.extract_features(d), c)
264272
for d, c in self.train_set]
265273
try:

0 commit comments

Comments
 (0)