Skip to content

Commit 33d9afd

Browse files
committed
return code instead of writing to file
1 parent 8275663 commit 33d9afd

File tree

2 files changed

+25
-21
lines changed

2 files changed

+25
-21
lines changed

src/sasctl/utils/pyml2ds/core.py

Lines changed: 16 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import io
12
import os
23
import pickle
34
import xml.etree.ElementTree as etree
@@ -43,7 +44,7 @@ def _check_type(model):
4344

4445

4546
@experimental
46-
def pyml2ds(in_file, out_file, out_var_name="P_TARGET"):
47+
def pyml2ds(in_file, out_var_name="P_TARGET"):
4748
"""Translate a gradient boosting model and write SAS scoring code to file.
4849
4950
Supported models are: xgboost, lightgbm and pmml gradient boosting.
@@ -54,11 +55,14 @@ def pyml2ds(in_file, out_file, out_var_name="P_TARGET"):
5455
Pickled object to translate. String is assumed to be a path to a picked
5556
file, file-like is assumed to be an open file handle to a pickle
5657
object, and bytes is assumed to be the raw pickled bytes.
57-
out_file : str
58-
Path to output file with SAS code.
5958
out_var_name : str (optional)
6059
Output variable name.
6160
61+
Returns
62+
-------
63+
str
64+
A SAS Data Step program implementing the model.
65+
6266
"""
6367

6468
try:
@@ -88,5 +92,12 @@ def pyml2ds(in_file, out_file, out_var_name="P_TARGET"):
8892
# Verify model is a valid type
8993
parser = _check_type(model)
9094
parser.out_var_name = out_var_name
91-
with open(out_file, "w") as f:
92-
parser.translate(f)
95+
96+
# Parser is currently written to expect a file input
97+
# Until refactored, use StringIO to collect the text in memory
98+
with io.StringIO() as f:
99+
parser.translate(f)
100+
101+
# Return contents of "file"
102+
f.seek(0)
103+
return f.read()

tests/unit/test_pyml2ds.py

Lines changed: 9 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -17,11 +17,10 @@
1717
DATA_PATH = os.path.join(dirname(dirname(__file__)), 'pyml2ds_data')
1818

1919

20-
def test_xgb2ds(tmpdir):
20+
def test_xgb2ds():
2121
pytest.importorskip('xgboost')
2222

2323
IN_PKL = os.path.join(DATA_PATH, 'xgb.pkl')
24-
OUT_SAS = os.path.join(str(tmpdir), 'xgb.sas')
2524
EXPECTED_SAS = os.path.join(DATA_PATH, 'xgb.sas')
2625

2726
from sasctl.utils.pyml2ds.connectors.ensembles.xgb import XgbTreeParser
@@ -41,18 +40,16 @@ def _leaf_value(self):
4140

4241
with mock.patch('sasctl.utils.pyml2ds.connectors.ensembles.xgb.XgbTreeParser') as parser:
4342
parser.return_value = test_parser
44-
pyml2ds(IN_PKL, OUT_SAS)
43+
result = pyml2ds(IN_PKL)
4544

46-
result = open(OUT_SAS, 'rb').read()
47-
expected = open(EXPECTED_SAS, 'rb').read()
45+
expected = open(EXPECTED_SAS, 'r').read()
4846
assert result == expected
4947

5048

51-
def test_lgb2ds(tmpdir):
49+
def test_lgb2ds():
5250
pytest.importorskip('lightgbm')
5351

5452
IN_PKL = os.path.join(DATA_PATH, 'lgb.pkl')
55-
OUT_SAS = os.path.join(str(tmpdir), 'lgb.sas')
5653
EXPECTED_SAS = os.path.join(DATA_PATH, 'lgb.sas')
5754

5855
from sasctl.utils.pyml2ds.connectors.ensembles.lgb import LightgbmTreeParser
@@ -72,17 +69,14 @@ def _leaf_value(self):
7269

7370
with mock.patch('sasctl.utils.pyml2ds.connectors.ensembles.lgb.LightgbmTreeParser') as parser:
7471
parser.return_value = test_parser
75-
pyml2ds(IN_PKL, OUT_SAS)
72+
result = pyml2ds(IN_PKL)
7673

77-
78-
result = open(OUT_SAS, 'rb').read()
79-
expected = open(EXPECTED_SAS, 'rb').read()
74+
expected = open(EXPECTED_SAS, 'r').read()
8075
assert result == expected
8176

8277

83-
def test_gbm2ds(tmpdir):
78+
def test_gbm2ds():
8479
IN_PKL = os.path.join(DATA_PATH, 'gbm.pmml')
85-
OUT_SAS = os.path.join(str(tmpdir), 'gbm.sas')
8680
EXPECTED_SAS = os.path.join(DATA_PATH, 'gbm.sas')
8781

8882
from sasctl.utils.pyml2ds.connectors.ensembles.pmml import PmmlTreeParser
@@ -102,10 +96,9 @@ def _leaf_value(self):
10296

10397
with mock.patch('sasctl.utils.pyml2ds.connectors.ensembles.pmml.PmmlTreeParser') as parser:
10498
parser.return_value = test_parser
105-
pyml2ds(IN_PKL, OUT_SAS)
99+
result = pyml2ds(IN_PKL)
106100

107-
result = open(OUT_SAS, 'rb').read()
108-
expected = open(EXPECTED_SAS, 'rb').read()
101+
expected = open(EXPECTED_SAS, 'r').read()
109102
assert result == expected
110103

111104

0 commit comments

Comments
 (0)