Skip to content

Commit 158d3a4

Browse files
committed
removed test= parameter
1 parent 250d7d0 commit 158d3a4

File tree

4 files changed

+69
-16
lines changed

4 files changed

+69
-16
lines changed

src/sasctl/utils/pyml2ds/basic/tree.py

Lines changed: 1 addition & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -97,7 +97,7 @@ def _remove_diacritic(self, input):
9797

9898
return output
9999

100-
def parse_node(self, f, node=None, test=False):
100+
def parse_node(self, f, node=None):
101101
"""Recursively parse tree node and write generated SAS code to file.
102102
103103
Attributes
@@ -108,8 +108,6 @@ def parse_node(self, f, node=None, test=False):
108108
Tree node to process.
109109
110110
"""
111-
if test:
112-
self.test = True
113111

114112
pnode = self._node
115113
if node is not None:
@@ -128,8 +126,6 @@ def parse_node(self, f, node=None, test=False):
128126
var = self._remove_diacritic(var)
129127

130128
split_value = self._split_value()
131-
if self.test:
132-
split_value = int(float(split_value))
133129

134130
cond = ""
135131
if self._go_left():
@@ -151,8 +147,6 @@ def parse_node(self, f, node=None, test=False):
151147
f.write(self._get_indent() + "end;\n")
152148
else:
153149
leaf_value = self._leaf_value()
154-
if self.test:
155-
leaf_value = int(float(leaf_value))
156150

157151
f.write(self._get_indent() + "treeValue%s = %s;\n"
158152
% (self._tree_id, leaf_value))

src/sasctl/utils/pyml2ds/connectors/ensembles/core.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ def _aggregate(self, booster_count):
2929
return "treeValue = sum({});\n".format(', '.join(
3030
["treeValue%d" % i for i in range(booster_count)]))
3131

32-
def translate(self, f, test=False):
32+
def translate(self, f):
3333
"""Translate a gradient boosting model and write SAS scoring code to a file.
3434
3535
Attributes
@@ -42,7 +42,7 @@ def translate(self, f, test=False):
4242
f.write("/* Parsing tree {}*/\n".format(booster_id))
4343

4444
self._tree_parser.init(tree, booster_id)
45-
self._tree_parser.parse_node(f, test=test)
45+
self._tree_parser.parse_node(f)
4646

4747
f.write("\n")
4848

src/sasctl/utils/pyml2ds/pyml2ds.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -38,12 +38,12 @@ def _check_type(model):
3838
else:
3939
raise RuntimeError("Unknown booster type: %s. Compatible types are: %s."
4040
" Check if corresponding library is installed."
41-
% type(model).__name__)
41+
% (type(model).__name__, ', '.join(comp_types)))
4242

4343
return parser
4444

4545

46-
def pyml2ds(in_file, out_file, out_var_name="P_TARGET", test=False):
46+
def pyml2ds(in_file, out_file, out_var_name="P_TARGET"):
4747
"""Translate a gradient boosting model and write SAS scoring code to file.
4848
4949
Supported models are: xgboost, lightgbm and pmml gradient boosting.
@@ -69,4 +69,4 @@ def pyml2ds(in_file, out_file, out_var_name="P_TARGET", test=False):
6969
parser = _check_type(model)
7070
parser.out_var_name = out_var_name
7171
with open(out_file, "w") as f:
72-
parser.translate(f, test=test)
72+
parser.translate(f)

tests/unit/test_pyml2ds.py

Lines changed: 63 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,12 +6,14 @@
66

77

88
import os
9-
dirname = os.path.dirname
109

1110
import pytest
11+
from six.moves import mock
12+
1213
from sasctl.utils.pyml2ds import pyml2ds
1314

1415

16+
dirname = os.path.dirname
1517
DATA_PATH = os.path.join(dirname(dirname(__file__)), 'pyml2ds_data')
1618

1719

@@ -22,7 +24,25 @@ def test_xgb2ds(tmpdir):
2224
OUT_SAS = os.path.join(str(tmpdir), 'xgb.sas')
2325
EXPECTED_SAS = os.path.join(DATA_PATH, 'xgb.sas')
2426

25-
pyml2ds(IN_PKL, OUT_SAS, test=True)
27+
from sasctl.utils.pyml2ds.connectors.ensembles.xgb import XgbTreeParser
28+
29+
# Expected output contains integer values instead of floats.
30+
# Convert to ensure match.
31+
class TestXgbTreeParser(XgbTreeParser):
32+
def _split_value(self):
33+
val = super(TestXgbTreeParser, self)._split_value()
34+
return int(float(val))
35+
36+
def _leaf_value(self):
37+
val = super(TestXgbTreeParser, self)._leaf_value()
38+
return int(float(val))
39+
40+
test_parser = TestXgbTreeParser()
41+
42+
with mock.patch('sasctl.utils.pyml2ds.connectors.ensembles.xgb.XgbTreeParser') as parser:
43+
parser.return_value = test_parser
44+
pyml2ds(IN_PKL, OUT_SAS)
45+
2646
result = open(OUT_SAS, 'rb').read()
2747
expected = open(EXPECTED_SAS, 'rb').read()
2848
assert result == expected
@@ -35,7 +55,26 @@ def test_lgb2ds(tmpdir):
3555
OUT_SAS = os.path.join(str(tmpdir), 'lgb.sas')
3656
EXPECTED_SAS = os.path.join(DATA_PATH, 'lgb.sas')
3757

38-
pyml2ds(IN_PKL, OUT_SAS, test=True)
58+
from sasctl.utils.pyml2ds.connectors.ensembles.lgb import LightgbmTreeParser
59+
60+
# Expected output contains integer values instead of floats.
61+
# Convert to ensure match.
62+
class TestLightgbmTreeParser(LightgbmTreeParser):
63+
def _split_value(self):
64+
val = super(TestLightgbmTreeParser, self)._split_value()
65+
return int(float(val))
66+
67+
def _leaf_value(self):
68+
val = super(TestLightgbmTreeParser, self)._leaf_value()
69+
return int(float(val))
70+
71+
test_parser = TestLightgbmTreeParser()
72+
73+
with mock.patch('sasctl.utils.pyml2ds.connectors.ensembles.lgb.LightgbmTreeParser') as parser:
74+
parser.return_value = test_parser
75+
pyml2ds(IN_PKL, OUT_SAS)
76+
77+
3978
result = open(OUT_SAS, 'rb').read()
4079
expected = open(EXPECTED_SAS, 'rb').read()
4180
assert result == expected
@@ -46,7 +85,27 @@ def test_gbm2ds(tmpdir):
4685
OUT_SAS = os.path.join(str(tmpdir), 'gbm.sas')
4786
EXPECTED_SAS = os.path.join(DATA_PATH, 'gbm.sas')
4887

49-
pyml2ds(IN_PKL, OUT_SAS, test=True)
88+
from sasctl.utils.pyml2ds.connectors.ensembles.pmml import PmmlTreeParser
89+
90+
# Expected output contains integer values instead of floats.
91+
# Convert to ensure match.
92+
class TestPmmlTreeParser(PmmlTreeParser):
93+
def _split_value(self):
94+
val = super(TestPmmlTreeParser, self)._split_value()
95+
return int(float(val))
96+
97+
def _leaf_value(self):
98+
val = super(TestPmmlTreeParser, self)._leaf_value()
99+
return int(float(val))
100+
101+
test_parser = TestPmmlTreeParser()
102+
103+
with mock.patch('sasctl.utils.pyml2ds.connectors.ensembles.pmml.PmmlTreeParser') as parser:
104+
parser.return_value = test_parser
105+
pyml2ds(IN_PKL, OUT_SAS)
106+
50107
result = open(OUT_SAS, 'rb').read()
51108
expected = open(EXPECTED_SAS, 'rb').read()
52109
assert result == expected
110+
111+

0 commit comments

Comments
 (0)