1212from sklearn .externals import six
1313from sklearn .tree import _tree
1414
15+ from .rule import Rule
16+
1517INTEGER_TYPES = (numbers .Integral , np .integer )
1618
1719
@@ -205,7 +207,8 @@ def fit(self, X, y, sample_weight=None):
205207 " in the data, but the data contains only one"
206208 " class: %r" % self .classes_ [0 ])
207209
208- if not isinstance (self .max_depth_duplication , int ) and self .max_depth_duplication is not None :
210+ if not isinstance (self .max_depth_duplication , int ) \
211+ and self .max_depth_duplication is not None :
209212 raise ValueError ("max_depth_duplication should be an integer"
210213 )
211214 if not set (self .classes_ ) == set ([0 , 1 ]):
@@ -265,7 +268,8 @@ def fit(self, X, y, sample_weight=None):
265268 max_features = self .max_samples_features ,
266269 bootstrap = self .bootstrap ,
267270 bootstrap_features = self .bootstrap_features ,
268- # oob_score=... XXX may be added if selection on tree perf needed.
271+ # oob_score=... XXX may be added
272+ # if selection on tree perf needed.
269273 # warm_start=... XXX may be added to increase computation perf.
270274 n_jobs = self .n_jobs ,
271275 random_state = self .random_state ,
@@ -281,7 +285,8 @@ def fit(self, X, y, sample_weight=None):
281285 max_features = self .max_samples_features ,
282286 bootstrap = self .bootstrap ,
283287 bootstrap_features = self .bootstrap_features ,
284- # oob_score=... XXX may be added if selection on tree perf needed.
288+ # oob_score=... XXX may be added
289+ # if selection on tree perf needed.
285290 # warm_start=... XXX may be added to increase computation perf.
286291 n_jobs = self .n_jobs ,
287292 random_state = self .random_state ,
@@ -345,6 +350,12 @@ def fit(self, X, y, sample_weight=None):
345350 for r in set (rules_from_tree )]
346351 rules_ += rules_from_tree
347352
353+ # Factorize rules before semantic tree filtering
354+ rules_ = [
355+ tuple (rule )
356+ for rule in
357+ [Rule (r , args = args ) for r , args in rules_ ]]
358+
348359 # keep only rules verifying precision_min and recall_min:
349360 for rule , score in rules_ :
350361 if score [0 ] >= self .precision_min and score [1 ] >= self .recall_min :
@@ -363,7 +374,7 @@ def fit(self, X, y, sample_weight=None):
363374 self .rules_ = sorted (self .rules_ .items (),
364375 key = lambda x : (x [1 ][0 ], x [1 ][1 ]), reverse = True )
365376
366- # count representation of feature
377+ # Deduplicate the rule using semantic tree
367378 if self .max_depth_duplication is not None :
368379 self .rules_ = self .deduplicate (self .rules_ )
369380 return self
@@ -576,7 +587,7 @@ def recurse(node, base_name):
576587 else :
577588 rule = str .join (' and ' , base_name )
578589 rule = (rule if rule != ''
579- else '== ' .join ([feature_names [0 ]] * 2 ))
590+ else ' == ' .join ([feature_names [0 ]] * 2 ))
580591 # a rule selecting all is set to "c0==c0"
581592 rules .append (rule )
582593
0 commit comments