Skip to content

Commit c13ab9d

Browse files
committed
support pickle and file inputs
1 parent 158d3a4 commit c13ab9d

File tree

3 files changed

+103
-15
lines changed

3 files changed

+103
-15
lines changed

src/sasctl/utils/pyml2ds/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
from .pyml2ds import pyml2ds
1+
from .core import pyml2ds

src/sasctl/utils/pyml2ds/pyml2ds.py renamed to src/sasctl/utils/pyml2ds/core.py

Lines changed: 32 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,8 @@
1+
import os
2+
import pickle
13
import xml.etree.ElementTree as etree
24

3-
try:
4-
import pickle
5-
except ImportError:
6-
pickle = None
7-
5+
import six
86
try:
97
import xgboost
108
except ImportError:
@@ -50,23 +48,43 @@ def pyml2ds(in_file, out_file, out_var_name="P_TARGET"):
5048
5149
Parameters
5250
----------
53-
in_file : str
54-
Path to file to be translated.
51+
in_file : str or bytes or file-like
52+
Pickled object to translate. String is assumed to be a path to a picked
53+
file, file-like is assumed to be an open file handle to a pickle
54+
object, and bytes is assumed to be the raw pickled bytes.
5555
out_file : str
5656
Path to output file with SAS code.
5757
out_var_name : str (optional)
5858
Output variable name.
5959
6060
"""
61-
# Load model file
62-
ext = ".pmml"
63-
if in_file[-len(ext):] == ext:
64-
model = etree.parse(in_file)
61+
62+
try:
63+
# In Python2 str could either be a path or the binary pickle data,
64+
# so check if its a valid filepath too.
65+
is_file_path = isinstance(in_file, six.string_types) and os.path.isfile(in_file)
66+
except TypeError:
67+
is_file_path = False
68+
69+
# Path to a PMML or pickle file
70+
if is_file_path:
71+
# Parse PMML files
72+
if os.path.splitext(in_file)[-1] == '.pmml':
73+
model = etree.parse(in_file)
74+
else:
75+
# Read pickled files
76+
with open(in_file, 'rb') as f:
77+
model = pickle.load(f)
78+
79+
elif isinstance(in_file, bytes):
80+
# Assume byte string is the actual pickled bytes
81+
model = pickle.loads(in_file)
6582
else:
66-
with open(in_file, 'rb') as mf:
67-
model = pickle.load(mf)
83+
# Assume a file object containing the pickled object
84+
model = pickle.load(in_file)
6885

86+
# Verify model is a valid type
6987
parser = _check_type(model)
7088
parser.out_var_name = out_var_name
7189
with open(out_file, "w") as f:
72-
parser.translate(f)
90+
parser.translate(f)

tests/unit/test_pyml2ds.py

Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -109,3 +109,73 @@ def _leaf_value(self):
109109
assert result == expected
110110

111111

112+
def test_path_input(tmpdir_factory):
113+
"""pyml2ds should accept a file path (str) as input."""
114+
import pickle
115+
from sasctl.utils.pyml2ds import pyml2ds
116+
117+
# The target "model" to use
118+
target = {'msg': 'hello world'}
119+
120+
# Pickle the "model" to a file
121+
temp_dir = tmpdir_factory.mktemp('pyml2ds')
122+
in_file = str(temp_dir.join('model.pkl'))
123+
out_file = str(temp_dir.join('model.sas'))
124+
with open(in_file, 'wb') as f:
125+
pickle.dump(target, f)
126+
127+
with mock.patch('sasctl.utils.pyml2ds.core._check_type') as check:
128+
check.translate.return_value = 'translated'
129+
pyml2ds(in_file, out_file)
130+
131+
# Verify _check_type should have been called with the "model"
132+
assert check.call_count == 1
133+
assert check.call_args[0][0] == target
134+
135+
136+
def test_file_input():
137+
"""pyml2ds should accept a file-like obj as input."""
138+
import io
139+
import pickle
140+
from sasctl.utils.pyml2ds import pyml2ds
141+
142+
# The target "model" to use
143+
target = {'msg': 'hello world'}
144+
145+
# Pickle the "model" to a file-like object
146+
in_file = io.BytesIO(pickle.dumps(target))
147+
out_file = 'model.sas'
148+
149+
with mock.patch('sasctl.utils.pyml2ds.core._check_type') as check:
150+
check.translate.return_value = 'translated'
151+
pyml2ds(in_file, out_file)
152+
153+
# Verify _check_type should have been called with the "model"
154+
assert check.call_count == 1
155+
assert check.call_args[0][0] == target
156+
157+
def test_pickle_input():
158+
"""pyml2ds should accept a binary pickle string as input."""
159+
import pickle
160+
from sasctl.utils.pyml2ds import pyml2ds
161+
162+
# The target "model" to use
163+
target = {'msg': 'hello world'}
164+
165+
# Pickle the "model" to a file-like object
166+
in_file = pickle.dumps(target)
167+
out_file = 'model.sas'
168+
169+
with mock.patch('sasctl.utils.pyml2ds.core._check_type') as check:
170+
check.translate.return_value = 'translated'
171+
pyml2ds(in_file, out_file)
172+
173+
# Verify _check_type should have been called with the "model"
174+
assert check.call_count == 1
175+
assert check.call_args[0][0] == target
176+
177+
178+
179+
180+
181+

0 commit comments

Comments
 (0)