Skip to content

Commit f1785c5

Browse files
Fixed tests
1 parent 6ea222e commit f1785c5

File tree

9 files changed

+6250
-6224
lines changed

9 files changed

+6250
-6224
lines changed

src/sasctl/utils/pyml2ds/basic/tree.py

Lines changed: 29 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
11
import abc, six
2+
import sys
3+
import unicodedata
24

35
@six.add_metaclass(abc.ABCMeta)
46
class TreeParser:
@@ -93,6 +95,17 @@ def _leaf_value(self):
9395
pass
9496

9597

98+
def _remove_diacritic(self, input):
99+
if sys.hexversion >= 0x3000000:
100+
output = unicodedata.normalize('NFKD', input).encode('ASCII', 'ignore').decode()
101+
else:
102+
# On Python < 3.0.0
103+
if type(input) == str:
104+
input = unicode(input, 'ISO-8859-1')
105+
output = unicodedata.normalize('NFKD', input).encode('ASCII', 'ignore')
106+
107+
return output
108+
96109

97110
"""Recursively parses tree node and writes generated SAS code to file.
98111
@@ -103,7 +116,10 @@ def _leaf_value(self):
103116
node: node
104117
Tree node to process.
105118
"""
106-
def parse_node(self, f, node=None):
119+
def parse_node(self, f, node=None, test=False):
120+
if test:
121+
self.test = True
122+
107123
pnode = self._node
108124
if node is not None:
109125
self._node = node
@@ -118,6 +134,12 @@ def parse_node(self, f, node=None):
118134

119135
if self._not_leaf():
120136
var = self._get_var()[:32]
137+
var = self._remove_diacritic(var)
138+
139+
split_value = self._split_value()
140+
if self.test:
141+
split_value = int(float(split_value))
142+
121143
cond = ""
122144
if self._go_left():
123145
cond = "missing({}) or ".format(var)
@@ -128,14 +150,18 @@ def parse_node(self, f, node=None):
128150
self.parse_node(f, node=self._missing_node())
129151
f.write(self._get_indent() + "end;\n")
130152

131-
f.write(self._get_indent() + "if ({}{} {} {}) then do;\n".format(cond, var, self._decision_type(), self._split_value()))
153+
f.write(self._get_indent() + "if ({}{} {} {}) then do;\n".format(cond, var, self._decision_type(), split_value))
132154
self.parse_node(f, node=self._left_node())
133155
f.write(self._get_indent() + "end;\n")
134156
f.write(self._get_indent() + "else do;\n")
135157
self.parse_node(f, node=self._right_node())
136158
f.write(self._get_indent() + "end;\n")
137159
else:
138-
f.write(self._get_indent() + "treeValue{} = {};\n".format(self._tree_id, self._leaf_value()))
160+
leaf_value = self._leaf_value()
161+
if self.test:
162+
leaf_value = int(float(leaf_value))
163+
164+
f.write(self._get_indent() + "treeValue{} = {};\n".format(self._tree_id, leaf_value))
139165

140166
self._node = pnode
141167
self._depth -= 1

src/sasctl/utils/pyml2ds/connectors/ensembles/core.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -33,12 +33,12 @@ def _aggregate(self, booster_count):
3333
f : file object
3434
Open file for writing output SAS code.
3535
"""
36-
def translate(self, f):
36+
def translate(self, f, test=False):
3737
for booster_id, tree in self._iter_trees():
3838
f.write("/* Parsing tree {}*/\n".format(booster_id))
3939

4040
self._tree_parser.init(tree, booster_id)
41-
self._tree_parser.parse_node(f)
41+
self._tree_parser.parse_node(f, test=test)
4242

4343
f.write("\n")
4444

src/sasctl/utils/pyml2ds/pyml2ds.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@ def _check_type(model):
5656
outVarName : str (optional)
5757
Output variable name.
5858
"""
59-
def pyml2ds(inFile, outFile, outVarName="P_TARGET"):
59+
def pyml2ds(inFile, outFile, outVarName="P_TARGET", test=False):
6060
# Load model file
6161
ext = ".pmml"
6262
if inFile[-len(ext):] == ext:
@@ -70,4 +70,4 @@ def pyml2ds(inFile, outFile, outVarName="P_TARGET"):
7070
parser = _check_type(model)
7171
parser.out_var_name = outVarName
7272
with open(outFile, "w") as f:
73-
parser.translate(f)
73+
parser.translate(f, test=test)

tests/pyml2ds_data/gbm.sas

Lines changed: 4936 additions & 4936 deletions
Large diffs are not rendered by default.

tests/pyml2ds_data/lgb.pkl

48 Bytes
Binary file not shown.

0 commit comments

Comments
 (0)