Skip to content

Commit f504104

Browse files
committed
pep8 fixes
1 parent a3eaf9a commit f504104

File tree

7 files changed

+119
-130
lines changed

7 files changed

+119
-130
lines changed
Lines changed: 38 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,10 @@
1-
import abc, six
1+
import abc
22
import sys
33
import unicodedata
4+
from builtins import str
5+
6+
import six
7+
48

59
@six.add_metaclass(abc.ABCMeta)
610
class TreeParser:
@@ -9,114 +13,102 @@ class TreeParser:
913
Attributes
1014
----------
1115
d : dict
12-
Dictionary for storing node hierarchy. Used not in all models.
16+
Dictionary for storing node hierarchy. Not used in all models.
1317
out_transform : string
14-
Output transformation for generated value. For example, for logreg is used: 1 / (1 + exp(-{0})), where {0} stands for resulting gbvalue.
18+
Output transformation for generated value. For example, if logreg is
19+
used: 1 / (1 + exp(-{0})), where {0} stands for resulting gbvalue.
1520
out_var_name : string
1621
Name used for output variable.
22+
1723
"""
1824
def __init__(self):
1925
self.d = {}
2026

27+
def init(self, root, tree_id=0):
28+
"""Custom init method. Need to be called before using TreeParser.
2129
22-
"""Custom init method. Need to be called before using TreeParser.
30+
Attributes
31+
----------
32+
root : node
33+
Tree root node.
34+
tree_id : int
35+
Id of current tree.
2336
24-
Attributes
25-
----------
26-
root : node
27-
Tree root node.
28-
tree_id : int
29-
Id of current tree.
30-
"""
31-
def init(self, root, tree_id=0):
37+
"""
3238
self._root = root
3339
self._node = root
3440
self._tree_id = tree_id
35-
3641
self._depth = -1
3742
self._indent = 4
3843

39-
4044
def _gen_dict(self):
4145
self.d = {}
4246

43-
4447
def _get_indent(self):
4548
return " " * self._indent * self._depth
4649

47-
4850
@abc.abstractmethod
4951
def _not_leaf(self):
5052
pass
5153

52-
5354
@abc.abstractmethod
5455
def _get_var(self):
5556
pass
5657

57-
5858
@abc.abstractmethod
5959
def _go_left(self):
6060
pass
6161

62-
6362
@abc.abstractmethod
6463
def _go_right(self):
6564
pass
6665

67-
6866
@abc.abstractmethod
6967
def _left_node(self):
7068
pass
7169

72-
7370
@abc.abstractmethod
7471
def _right_node(self):
7572
pass
7673

77-
7874
@abc.abstractmethod
7975
def _missing_node(self):
8076
pass
8177

82-
8378
@abc.abstractmethod
8479
def _split_value(self):
8580
pass
8681

87-
8882
@abc.abstractmethod
8983
def _decision_type(self):
9084
pass
9185

92-
9386
@abc.abstractmethod
9487
def _leaf_value(self):
9588
pass
9689

97-
9890
def _remove_diacritic(self, input):
9991
if sys.hexversion >= 0x3000000:
10092
output = unicodedata.normalize('NFKD', input).encode('ASCII', 'ignore').decode()
10193
else:
10294
# On Python < 3.0.0
10395
if type(input) == str:
104-
input = unicode(input, 'ISO-8859-1')
96+
input = str(input, 'ISO-8859-1')
10597
output = unicodedata.normalize('NFKD', input).encode('ASCII', 'ignore')
106-
98+
10799
return output
108100

101+
def parse_node(self, f, node=None, test=False):
102+
"""Recursively parse tree node and write generated SAS code to file.
109103
110-
"""Recursively parses tree node and writes generated SAS code to file.
104+
Attributes
105+
----------
106+
f : file object
107+
Open file for writing output SAS code.
108+
node: node
109+
Tree node to process.
111110
112-
Attributes
113-
----------
114-
f : file object
115-
Open file for writing output SAS code.
116-
node: node
117-
Tree node to process.
118-
"""
119-
def parse_node(self, f, node=None, test=False):
111+
"""
120112
if test:
121113
self.test = True
122114

@@ -135,7 +127,7 @@ def parse_node(self, f, node=None, test=False):
135127
if self._not_leaf():
136128
var = self._get_var()[:32]
137129
var = self._remove_diacritic(var)
138-
130+
139131
split_value = self._split_value()
140132
if self.test:
141133
split_value = int(float(split_value))
@@ -146,11 +138,13 @@ def parse_node(self, f, node=None, test=False):
146138
elif self._go_right():
147139
cond = "not missing({}) and ".format(var)
148140
else:
149-
f.write(self._get_indent() + "if (missing({})) then do;\n".format(var))
141+
f.write(self._get_indent() + "if (missing(%s)) then do;\n"
142+
% var)
150143
self.parse_node(f, node=self._missing_node())
151144
f.write(self._get_indent() + "end;\n")
152-
153-
f.write(self._get_indent() + "if ({}{} {} {}) then do;\n".format(cond, var, self._decision_type(), split_value))
145+
146+
f.write(self._get_indent() + "if ({}{} {} {}) then do;\n".format(
147+
cond, var, self._decision_type(), split_value))
154148
self.parse_node(f, node=self._left_node())
155149
f.write(self._get_indent() + "end;\n")
156150
f.write(self._get_indent() + "else do;\n")
@@ -161,7 +155,8 @@ def parse_node(self, f, node=None, test=False):
161155
if self.test:
162156
leaf_value = int(float(leaf_value))
163157

164-
f.write(self._get_indent() + "treeValue{} = {};\n".format(self._tree_id, leaf_value))
165-
158+
f.write(self._get_indent() + "treeValue%s = %s;\n"
159+
% (self._tree_id, leaf_value))
160+
166161
self._node = pnode
167162
self._depth -= 1
Lines changed: 21 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,47 +1,52 @@
1-
import abc, six
1+
import abc
2+
3+
import six
24

35

46
@six.add_metaclass(abc.ABCMeta)
57
class EnsembleParser:
6-
"""Abstract class for parsing decision tree ensebmles.
8+
"""Abstract class for parsing decision tree ensembles.
79
810
Attributes
911
----------
1012
out_transform : string
11-
Output transformation for generated value. For example, for logreg is used: 1 / (1 + exp(-{0})), where {0} stands for resulting gbvalue.
13+
Output transformation for generated value. For example, if logreg is
14+
used: 1 / (1 + exp(-{0})), where {0} stands for resulting gbvalue.
1215
out_var_name : string
1316
Name used for output variable.
17+
1418
"""
19+
1520
def __init__(self, out_transform="{0}", out_var_name="P_TARGET"):
1621
self.out_transform = out_transform
1722
self.out_var_name = out_var_name
1823

19-
2024
@abc.abstractmethod
2125
def _iter_trees(self):
2226
pass
2327

24-
2528
def _aggregate(self, booster_count):
26-
return "treeValue = sum({});\n".format(', '.join(["treeValue%d" % i for i in range(booster_count)]))
29+
return "treeValue = sum({});\n".format(', '.join(
30+
["treeValue%d" % i for i in range(booster_count)]))
2731

32+
def translate(self, f, test=False):
33+
"""Translate a gradient boosting model and write SAS scoring code to a file.
2834
29-
"""Translates gradient boosting model and writes SAS scoring code to file.
35+
Attributes
36+
----------
37+
f : file object
38+
Open file for writing output SAS code.
3039
31-
Attributes
32-
----------
33-
f : file object
34-
Open file for writing output SAS code.
35-
"""
36-
def translate(self, f, test=False):
40+
"""
3741
for booster_id, tree in self._iter_trees():
3842
f.write("/* Parsing tree {}*/\n".format(booster_id))
39-
43+
4044
self._tree_parser.init(tree, booster_id)
4145
self._tree_parser.parse_node(f, test=test)
4246

4347
f.write("\n")
44-
48+
4549
f.write("/* Getting target probability */\n")
4650
f.write(self._aggregate(booster_id + 1))
47-
f.write("{} = {};\n".format(self.out_var_name, self.out_transform.format("treeValue")))
51+
f.write("{} = {};\n".format(self.out_var_name,
52+
self.out_transform.format("treeValue")))

src/sasctl/utils/pyml2ds/connectors/ensembles/lgb.py

Lines changed: 5 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -9,39 +9,30 @@ class LightgbmTreeParser(TreeParser):
99
def _not_leaf(self):
1010
return 'split_feature' in self._node
1111

12-
1312
def _get_var(self):
1413
return self._features[self._node['split_feature']]
1514

16-
1715
def _go_left(self):
1816
return self._node['default_left']
1917

20-
2118
def _go_right(self):
2219
return (not self._node['default_left'])
2320

24-
2521
def _left_node(self):
2622
return self._node['left_child']
2723

28-
2924
def _right_node(self):
3025
return self._node['right_child']
3126

32-
3327
def _missing_node(self):
3428
return None
3529

36-
3730
def _split_value(self):
3831
return self._node['threshold']
3932

40-
4133
def _decision_type(self):
4234
return self._node['decision_type']
4335

44-
4536
def _leaf_value(self):
4637
return self._node['leaf_value']
4738

@@ -53,23 +44,25 @@ class LightgbmParser(EnsembleParser):
5344
----------
5445
booster : lightgbm.basic.Booster
5546
Booster of lightgbm model.
47+
5648
"""
5749
def __init__(self, booster):
5850
super(LightgbmParser, self).__init__()
5951

6052
self._booster = booster
61-
6253
self._dump = booster.dump_model()
54+
6355
if self._dump['objective'] != 'binary sigmoid:1':
64-
raise Exception("Unfortunately only binary sigmoid objective function is supported right now. Your objective is %s. Please, open an issue at https://gitlab.sas.com/from-russia-with-love/lgb2sas." % self.dump['objective'])
56+
raise ValueError("Only binary sigmoid objective function is "
57+
"currently supported. Received '%s'."
58+
% self.dump['objective'])
6559

6660
self._features = self._dump['feature_names']
6761
self.out_transform = "1 / (1 + exp(-{0}))"
6862

6963
self._tree_parser = LightgbmTreeParser()
7064
self._tree_parser._features = self._features
7165

72-
7366
def _iter_trees(self):
7467
for tree in self._dump['tree_info']:
7568
yield tree['tree_index'], tree['tree_structure']

0 commit comments

Comments
 (0)