Skip to content

Commit a3eaf9a

Browse files
author
Jon Walker
authored
Merge pull request #29 from InvalidPointer/master
Added pyml2ds module
2 parents a54a360 + bc40918 commit a3eaf9a

File tree

19 files changed

+34474
-0
lines changed

19 files changed

+34474
-0
lines changed

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ dist
3232
*.egg-info
3333
.eggs
3434
.pypirc
35+
*.pyc
3536

3637
## tox testing tool
3738
.tox

src/sasctl/utils/pyml2ds/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
from .pyml2ds import pyml2ds
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
from .tree import TreeParser
Lines changed: 167 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,167 @@
1+
import abc, six
2+
import sys
3+
import unicodedata
4+
5+
@six.add_metaclass(abc.ABCMeta)
class TreeParser:
    """Abstract class for parsing decision tree.

    Subclasses implement the node accessors (``_not_leaf``, ``_get_var``,
    ...) for one model format. ``parse_node`` then walks the tree through
    those accessors and writes nested SAS ``if``/``else`` blocks to a file.

    Attributes
    ----------
    d : dict
        Dictionary for storing node hierarchy. Used not in all models.
    test : bool
        When true, split values and leaf values are truncated to integers
        (``int(float(value))``) before being written.
    """

    # Class-level default so that ``parse_node`` can read ``self.test`` even
    # when it has never been called with ``test=True``.
    test = False

    def __init__(self):
        self.d = {}

    def init(self, root, tree_id=0):
        """Custom init method. Need to be called before using TreeParser.

        Parameters
        ----------
        root : node
            Tree root node.
        tree_id : int
            Id of current tree. Used as the suffix of the ``treeValue``
            variable assigned in the generated code.
        """
        self._root = root
        self._node = root
        self._tree_id = tree_id

        # -1 so that the first ``parse_node`` call runs at depth 0.
        self._depth = -1
        self._indent = 4

    def _gen_dict(self):
        """Rebuild ``self.d``; the base implementation just empties it."""
        self.d = {}

    def _get_indent(self):
        """Return the leading whitespace for the current nesting depth."""
        return " " * self._indent * self._depth

    @abc.abstractmethod
    def _not_leaf(self):
        """Return a truthy value if the current node is a split node."""
        pass

    @abc.abstractmethod
    def _get_var(self):
        """Return the name of the variable the current node splits on."""
        pass

    @abc.abstractmethod
    def _go_left(self):
        """Return a truthy value if missing values follow the left branch."""
        pass

    @abc.abstractmethod
    def _go_right(self):
        """Return a truthy value if missing values follow the right branch."""
        pass

    @abc.abstractmethod
    def _left_node(self):
        """Return the child visited when the split condition holds."""
        pass

    @abc.abstractmethod
    def _right_node(self):
        """Return the child visited when the split condition fails."""
        pass

    @abc.abstractmethod
    def _missing_node(self):
        """Return the child used for missing values.

        Only consulted when both ``_go_left`` and ``_go_right`` are falsy.
        """
        pass

    @abc.abstractmethod
    def _split_value(self):
        """Return the threshold the split variable is compared against."""
        pass

    @abc.abstractmethod
    def _decision_type(self):
        """Return the comparison operator text written into the condition."""
        pass

    @abc.abstractmethod
    def _leaf_value(self):
        """Return the value assigned when the current node is a leaf."""
        pass

    def _remove_diacritic(self, input):
        """Return ``input`` reduced to plain ASCII.

        The text is NFKD-normalized and then encoded to ASCII with errors
        ignored, so accents are stripped from letters and any character
        without an ASCII form is dropped.
        """
        if sys.hexversion >= 0x3000000:
            output = unicodedata.normalize('NFKD', input).encode('ASCII', 'ignore').decode()
        else:
            # On Python < 3.0.0 a ``str`` is a byte string and has to be
            # decoded before it can be normalized.
            if isinstance(input, str):
                input = unicode(input, 'ISO-8859-1')  # noqa: F821 (Python 2 only)
            output = unicodedata.normalize('NFKD', input).encode('ASCII', 'ignore')

        return output

    def parse_node(self, f, node=None, test=False):
        """Recursively parses tree node and writes generated SAS code to file.

        Parameters
        ----------
        f : file object
            Open file for writing output SAS code.
        node : node
            Tree node to process. When None, parsing starts at the root
            given to ``init``.
        test : bool
            When true, switches the parser into test mode (see the ``test``
            attribute). The flag is sticky: it stays set for the recursive
            calls and for later calls on the same parser.
        """
        if test:
            self.test = True

        # Remember the caller's node so it can be restored on the way out.
        pnode = self._node
        self._node = self._root if node is None else node

        if self._node == self._root:
            self.d = dict()
            self._gen_dict()

        self._depth += 1
        # Depth is restored by every nested call, so one lookup is enough.
        indent = self._get_indent()

        if self._not_leaf():
            # Name is cut to 32 characters before ASCII-folding; presumably
            # the SAS variable-name length limit - TODO confirm.
            var = self._remove_diacritic(self._get_var()[:32])

            split_value = self._split_value()
            if self.test:
                split_value = int(float(split_value))

            cond = ""
            if self._go_left():
                cond = "missing({}) or ".format(var)
            elif self._go_right():
                cond = "not missing({}) and ".format(var)
            else:
                # NOTE(review): this block is followed by a plain ``if``
                # rather than ``else if``, so the comparison below is still
                # evaluated for missing values - confirm this is intended.
                f.write(indent + "if (missing({})) then do;\n".format(var))
                self.parse_node(f, node=self._missing_node())
                f.write(indent + "end;\n")

            f.write(indent + "if ({}{} {} {}) then do;\n".format(cond, var, self._decision_type(), split_value))
            self.parse_node(f, node=self._left_node())
            f.write(indent + "end;\n")
            f.write(indent + "else do;\n")
            self.parse_node(f, node=self._right_node())
            f.write(indent + "end;\n")
        else:
            leaf_value = self._leaf_value()
            if self.test:
                leaf_value = int(float(leaf_value))

            f.write(indent + "treeValue{} = {};\n".format(self._tree_id, leaf_value))

        self._node = pnode
        self._depth -= 1
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
from .ensembles import XgbParser, LightgbmParser, PmmlParser
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
from .xgb import XgbParser
2+
from .lgb import LightgbmParser
3+
from .pmml import PmmlParser
Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
import abc, six
2+
3+
4+
@six.add_metaclass(abc.ABCMeta)
class EnsembleParser:
    """Abstract class for parsing decision tree ensembles.

    Subclasses provide ``_iter_trees`` and set ``self._tree_parser`` to the
    tree parser used to translate each individual tree.

    Attributes
    ----------
    out_transform : string
        Output transformation for generated value. For example, for logreg is used: 1 / (1 + exp(-{0})), where {0} stands for resulting gbvalue.
    out_var_name : string
        Name used for output variable.
    """
    def __init__(self, out_transform="{0}", out_var_name="P_TARGET"):
        self.out_transform = out_transform
        self.out_var_name = out_var_name

    @abc.abstractmethod
    def _iter_trees(self):
        """Yield ``(booster_id, tree_root)`` pairs, one per tree.

        ``translate`` takes the last yielded id plus one as the number of
        trees, so ids are expected to run from 0 upwards.
        """
        pass

    def _aggregate(self, booster_count):
        """Return the SAS statement summing ``treeValue0`` .. ``treeValue<n-1>``."""
        tree_vars = ', '.join("treeValue%d" % i for i in range(booster_count))
        return "treeValue = sum({});\n".format(tree_vars)

    def translate(self, f, test=False):
        """Translates gradient boosting model and writes SAS scoring code to file.

        Parameters
        ----------
        f : file object
            Open file for writing output SAS code.
        test : bool
            Passed through to the tree parser's ``parse_node``.

        Raises
        ------
        ValueError
            If ``_iter_trees`` yields no trees, since there would be nothing
            to aggregate.
        """
        booster_count = 0
        for booster_id, tree in self._iter_trees():
            f.write("/* Parsing tree {}*/\n".format(booster_id))

            self._tree_parser.init(tree, booster_id)
            self._tree_parser.parse_node(f, test=test)

            f.write("\n")
            booster_count = booster_id + 1

        # Previously an empty model failed with UnboundLocalError on the
        # loop variable; fail with an explicit message instead.
        if booster_count == 0:
            raise ValueError("Model contains no trees to translate.")

        f.write("/* Getting target probability */\n")
        f.write(self._aggregate(booster_count))
        f.write("{} = {};\n".format(self.out_var_name, self.out_transform.format("treeValue")))
Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,75 @@
1+
from sasctl.utils.pyml2ds.basic import TreeParser
2+
from .core import EnsembleParser
3+
4+
5+
class LightgbmTreeParser(TreeParser):
    """Tree parser for nodes of a lightgbm model dump.

    Every node is a dict. ``self._features`` (the list of feature names) is
    assigned from outside by ``LightgbmParser``.
    """

    def _not_leaf(self):
        """A node is a split node when it carries a 'split_feature' key."""
        is_split = 'split_feature' in self._node
        return is_split

    def _get_var(self):
        """Look up the split feature's name by its index."""
        feature_index = self._node['split_feature']
        return self._features[feature_index]

    def _go_left(self):
        """Return the node's 'default_left' flag."""
        default_left = self._node['default_left']
        return default_left

    def _go_right(self):
        """Return the negation of the node's 'default_left' flag."""
        default_left = self._node['default_left']
        return not default_left

    def _left_node(self):
        """Return the 'left_child' entry of the current node."""
        child = self._node['left_child']
        return child

    def _right_node(self):
        """Return the 'right_child' entry of the current node."""
        child = self._node['right_child']
        return child

    def _missing_node(self):
        """Always None: this parser has no separate missing-value child."""
        return None

    def _split_value(self):
        """Return the node's 'threshold'."""
        threshold = self._node['threshold']
        return threshold

    def _decision_type(self):
        """Return the node's 'decision_type' unchanged."""
        operator = self._node['decision_type']
        return operator

    def _leaf_value(self):
        """Return the node's 'leaf_value'."""
        value = self._node['leaf_value']
        return value
47+
48+
49+
class LightgbmParser(EnsembleParser):
    """Class for parsing lightgbm model.

    Parameters
    ----------
    booster : lightgbm.basic.Booster
        Booster of lightgbm model.

    Raises
    ------
    Exception
        If the dumped model's objective is not 'binary sigmoid:1'.
    """
    def __init__(self, booster):
        super(LightgbmParser, self).__init__()

        self._booster = booster

        self._dump = booster.dump_model()
        if self._dump['objective'] != 'binary sigmoid:1':
            # The message must read ``self._dump``; the former ``self.dump``
            # does not exist and turned this into an AttributeError.
            raise Exception("Unfortunately only binary sigmoid objective function is supported right now. Your objective is %s. Please, open an issue at https://gitlab.sas.com/from-russia-with-love/lgb2sas." % self._dump['objective'])

        self._features = self._dump['feature_names']
        # Logistic transform applied to the summed tree values.
        self.out_transform = "1 / (1 + exp(-{0}))"

        self._tree_parser = LightgbmTreeParser()
        # The tree parser resolves split feature indices through this list.
        self._tree_parser._features = self._features

    def _iter_trees(self):
        """Yield ``(tree_index, tree_structure)`` for every dumped tree."""
        for tree in self._dump['tree_info']:
            yield tree['tree_index'], tree['tree_structure']
Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,78 @@
1+
from sasctl.utils.pyml2ds.basic import TreeParser
2+
from .core import EnsembleParser
3+
4+
5+
class PmmlTreeParser(TreeParser):
    """Tree parser for the ``Node`` elements of a pmml gradient boosting tree.

    The split of a node is read from the ``SimplePredicate`` of its first
    child ``Node`` element.
    """

    # PMML operator name -> comparison operator written to the output.
    _OPERATORS = {
        'lessThan': '<',
        'lessOrEqual': '<=',
        'greaterThan': '>',
        'greaterOrEqual': '>=',
    }

    def _first_predicate(self):
        """Return the SimplePredicate of the current node's first child."""
        first_child = self._node.find('Node')
        return first_child.find('SimplePredicate')

    def _not_leaf(self):
        """Return the 'defaultChild' attribute (None when absent)."""
        default_child = self._node.get('defaultChild')
        return default_child

    def _get_var(self):
        """Return the 'field' attribute of the split predicate."""
        return self._first_predicate().get('field')

    def _go_left(self):
        """True when the first child's id equals 'defaultChild'."""
        first_id = self._node.find('Node').get('id')
        return first_id == self._node.get('defaultChild')

    def _go_right(self):
        """True when the first child's id differs from 'defaultChild'."""
        first_id = self._node.find('Node').get('id')
        return first_id != self._node.get('defaultChild')

    def _left_node(self):
        """Return the first child ``Node`` element."""
        children = self._node.findall('Node')
        return children[0]

    def _right_node(self):
        """Return the second child ``Node`` element."""
        children = self._node.findall('Node')
        return children[1]

    def _missing_node(self):
        """Always None: this parser has no separate missing-value child."""
        return None

    def _split_value(self):
        """Return the 'value' attribute of the split predicate."""
        return self._first_predicate().get('value')

    def _decision_type(self):
        """Map the predicate's 'operator' attribute to its symbol."""
        operator_name = self._first_predicate().get('operator')
        return self._OPERATORS[operator_name]

    def _leaf_value(self):
        """Return the 'score' attribute of the current node."""
        score = self._node.get('score')
        return score
48+
49+
50+
class PmmlParser(EnsembleParser):
    """Class for parsing pmml gradient boosting models.

    Parameters
    ----------
    tree_root : etree.Element
        Root node of pmml gradient boosting forest. Namespace prefixes are
        stripped from its element tags in place.
    """
    def __init__(self, tree_root):
        super(PmmlParser, self).__init__()

        self._tree_root = tree_root
        # ``iter()`` replaces ``getiterator()``, which was removed from
        # xml.etree.ElementTree in Python 3.9.
        for elem in tree_root.iter():
            # Skip nodes whose tag is not a string (no ``find`` method).
            if not hasattr(elem.tag, 'find'):
                continue
            # Drop a leading "{namespace}" so plain tag names can be used
            # in the find() paths below.
            i = elem.tag.find('}')
            if i >= 0:
                elem.tag = elem.tag[i+1:]

        self._forest = tree_root.find('MiningModel/Segmentation')[0].find('MiningModel')

        # Fall back to 0 when the attribute is absent (assumed to match the
        # PMML schema default - TODO confirm) instead of writing the literal
        # "None" into the generated code.
        rescaleConstant = self._forest.find('Targets/Target').get('rescaleConstant', '0')
        # Yields "1 / (1 + exp(-({0} + <rescaleConstant>)))"; the doubled
        # braces keep the {0} placeholder for EnsembleParser.translate.
        self.out_transform = "1 / (1 + exp(-({{0}} + {})))".format(rescaleConstant)

        self._tree_parser = PmmlTreeParser()

    def _iter_trees(self):
        """Yield ``(index, root Node element)`` for every segment's tree."""
        for booster_id, tree_elem in enumerate(self._forest.find('Segmentation')):
            yield booster_id, tree_elem.find('TreeModel/Node')

0 commit comments

Comments
 (0)