Skip to content

Commit 96cedb6

Browse files
author
Jon Walker
authored
Merge pull request #7 from jlwalke2/feature-pymas
sklearn model publishing
2 parents 6996e71 + e2f8c19 commit 96cedb6

File tree

8 files changed

+1175
-297
lines changed

8 files changed

+1175
-297
lines changed

src/sasctl/tasks.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -131,7 +131,7 @@ def register_model(model, name, project, repository=None, input=None, version='l
131131

132132
# Generate PyMAS wrapper
133133
try:
134-
mas_module = from_pickle(model_pkl, 'predict', input_types=input)
134+
mas_module = from_pickle(model_pkl, 'predict', input_types=input, array_input=True)
135135
assert isinstance(mas_module, PyMAS)
136136

137137
# Include score code files from ESP and MAS

src/sasctl/utils/pymas/core.py

Lines changed: 93 additions & 63 deletions
Original file line numberDiff line numberDiff line change
@@ -8,34 +8,35 @@
88
import base64
99
from collections import OrderedDict
1010
import importlib
11+
import pickle
1112
import os
1213
import sys
1314

14-
try:
15-
import dill
16-
from dill import load, loads, dumps, dump
17-
except ImportError:
18-
dill = None
19-
from pickle import load, loads, dumps, dump
15+
2016
import six
2117

22-
from .ds2 import DS2Method, DS2Thread, DS2Variable, DS2Package
18+
from .ds2 import DS2Thread, DS2Variable, DS2Package
2319
from .python import ds2_variables
2420

2521

26-
def build_wrapper_function(func, variables, array_input, return_msg=True):
22+
def build_wrapper_function(func, variables, array_input, setup=None,
23+
return_msg=True):
2724
"""Wraps a function to ensure compatibility when called by PyMAS.
2825
29-
PyMAS has strict expectations regarding the format of any function called directly by PyMAS.
30-
Isolating the desired function inside a wrapping function provides a simple way to ensure that functions
31-
called by PyMAS are compliant.
26+
PyMAS has strict expectations regarding the format of any function called
27+
directly by PyMAS. Isolating the desired function inside a wrapping
28+
function provides a simple way to ensure that functions called by PyMAS
29+
are compliant.
3230
3331
Parameters
3432
----------
3533
func : function or str
3634
Function name or an instance of Function which will be wrapped
3735
variables : list of DS2Variable
3836
array_input : bool
37+
Whether `variables` should be combined into a single array before passing to `func`
38+
setup : iterable
39+
Python source code lines to be executed during package setup
3940
return_msg : bool
4041
4142
Returns
@@ -45,36 +46,54 @@ def build_wrapper_function(func, variables, array_input, return_msg=True):
4546
4647
Notes
4748
-----
48-
The format for the `# Output: ` is very strict. It must be exactly "# Output: <var>, <var>". Any changes to
49-
spelling, capitalization, punctuation, or spacing will result in an error when the DS2 code is executed.
49+
The format for the `# Output: ` is very strict. It must be exactly
50+
"# Output: <var>, <var>". Any changes to spelling, capitalization,
51+
punctuation, or spacing will result in an error when the DS2 code is
52+
executed.
5053
5154
"""
5255

5356
input_names = [v.name for v in variables if not v.out]
5457
output_names = [v.name for v in variables if v.out]
55-
5658
args = input_names
57-
5859
func = func.__name__ if callable(func) else func
5960

6061
# Statement to execute the function w/ provided parameters
6162
if array_input:
62-
func_call = '{}(np.asarray({}).reshape((1,-1)))'.format(func, ','.join(args))
63+
func_call = '{}(np.array([{}]).reshape((1, -1)))'.format(func, ','.join(args))
6364
else:
6465
func_call = '{}({})'.format(func, ','.join(args))
6566

6667
# TODO: Verify that # of values returned by wrapped func matches length of output_names
6768
# TODO: cast all return types before returning (DS2 errors out if not exact match)
6869

69-
# NOTE: 'Output:' section is required. All return variables must be listed separated by ', '
70-
definition = ('def wrapper({}):'.format(', '.join(args)),
70+
# NOTE: 'Output:' section is required. All return variables must be listed
71+
# separated by ', '
72+
73+
if setup:
74+
header = ('try:', ) + \
75+
tuple(' ' + line for line in setup) + \
76+
(' _compile_error = None',
77+
'except Exception as e:',
78+
' _compile_error = e',
79+
'')
80+
else:
81+
header = ('', )
82+
83+
definition = header +\
84+
('def wrapper({}):'.format(', '.join(args)),
7185
' "Output: {}"'.format(', '.join(output_names + ['msg']) if return_msg
7286
else ', '.join(output_names)),
7387
' result = None',
7488
' try:',
89+
' global _compile_error',
90+
' if _compile_error is not None:',
91+
' raise _compile_error',
7592
' msg = ""' if return_msg else '',
7693
' import numpy as np',
77-
' result = float({})'.format(func_call),
94+
' result = {}'.format(func_call),
95+
' if result.size == 1:',
96+
' result = np.asscalar(result)',
7897
' except Exception as e:',
7998
' msg = str(e)' if return_msg else '',
8099
' if result is None:',
@@ -110,12 +129,14 @@ def from_inline(func, input_types=None, array_input=False, return_code=True, ret
110129
111130
"""
112131

113-
obj = dumps(func)
132+
obj = pickle.dumps(func)
114133
return from_pickle(obj, None, input_types, array_input, return_code, return_message)
115134

116135

117-
def from_python_file(file, func_name=None, input_types=None, array_input=False, return_code=True, return_message=True):
118-
""" Creates a PyMAS wrapper to execute a function defined in an external .py file.
136+
def from_python_file(file, func_name=None, input_types=None, array_input=False,
137+
return_code=True, return_message=True):
138+
"""Creates a PyMAS wrapper to execute a function defined in an
139+
external .py file.
119140
120141
Parameters
121142
----------
@@ -127,7 +148,8 @@ def from_python_file(file, func_name=None, input_types=None, array_input=False,
127148
The expected type for each input value of the target function.
128149
Can be ommitted if target function includes type hints.
129150
array_input : bool
130-
Whether the function inputs should be treated as an array instead of individual parameters
151+
Whether the function inputs should be treated as an array instead of
152+
individual parameters
131153
return_code : bool
132154
Whether the DS2-generated return code should be included
133155
return_message : bool
@@ -156,29 +178,34 @@ def from_python_file(file, func_name=None, input_types=None, array_input=False,
156178
target_func = getattr(module, func_name)
157179

158180
if not callable(target_func):
159-
raise RuntimeError("Could not find a valid function named {}".format(func_name))
181+
raise RuntimeError("Could not find a valid function named %s"
182+
% func_name)
160183

161184
with open(file, 'r') as f:
162185
code = [line.strip('\n') for line in f.readlines()]
163186

164-
return _build_pymas(target_func, None, input_types, array_input, return_code, return_message, code)
187+
return _build_pymas(target_func, None, input_types, array_input,
188+
return_code, return_message, code)
165189

166190

167-
def from_pickle(file, func_name=None, input_types=None, array_input=False, return_code=True, return_message=True):
191+
def from_pickle(file, func_name=None, input_types=None, array_input=False,
192+
return_code=True, return_message=True):
168193
"""Create a deployable DS2 package from a Python pickle file.
169194
170195
Parameters
171196
----------
172197
file : str or bytes or file_like
173-
Pickled object to use. String is assumed to be a path to a picked file, file_like is assumed to be an open
174-
file handle to a pickle object, and bytes is assumed to be the raw pickled bytes.
198+
Pickled object to use. String is assumed to be a path to a picked
199+
file, file_like is assumed to be an open file handle to a pickle
200+
object, and bytes is assumed to be the raw pickled bytes.
175201
func_name : str
176202
Name of the target function to call
177203
input_types : list of type, optional
178204
The expected type for each input value of the target function.
179205
Can be ommitted if target function includes type hints.
180206
array_input : bool
181-
Whether the function inputs should be treated as an array instead of individual parameters
207+
Whether the function inputs should be treated as an array instead of
208+
individual parameters
182209
return_code : bool
183210
Whether the DS2-generated return code should be included
184211
return_message : bool
@@ -190,39 +217,41 @@ def from_pickle(file, func_name=None, input_types=None, array_input=False, retur
190217
Generated DS2 code which can be executed in a SAS scoring environment
191218
192219
"""
193-
194220
try:
195-
# In Python2 str could either be a path or the binary pickle data, so check if its a valid filepath too.
221+
# In Python2 str could either be a path or the binary pickle data,
222+
# so check if its a valid filepath too.
196223
is_file_path = isinstance(file, six.string_types) and os.path.isfile(file)
197224
except TypeError:
198225
is_file_path = False
199226

200227
# Path to a pickle file
201228
if is_file_path:
202229
with open(file, 'rb') as f:
203-
obj = load(f)
230+
obj = pickle.load(f)
204231

205232
# The actual pickled bytes
206233
elif isinstance(file, bytes):
207-
obj = loads(file)
234+
obj = pickle.loads(file)
208235
else:
209-
obj = load(file)
236+
obj = pickle.load(file)
210237

211238
# Encode the pickled data so we can inline it in the DS2 package
212-
pkl = base64.b64encode(dumps(obj))
239+
pkl = base64.b64encode(pickle.dumps(obj))
213240

214-
package = 'dill' if dill else 'pickle'
241+
code = ('import pickle, base64',
242+
# Replace b' with " before embedding in DS2.
243+
'bytes = {}'.format(pkl).replace("'", '"'),
244+
'obj = pickle.loads(base64.b64decode(bytes))')
215245

216-
code = ('import %s, base64' % package,
217-
'bytes = {}'.format(pkl).replace("'", '"'), # Replace b' with " before embedding in DS2.
218-
'obj = %s.loads(base64.b64decode(bytes))' % package)
246+
return _build_pymas(obj, func_name, input_types, array_input, return_code,
247+
return_message, code)
219248

220-
return _build_pymas(obj, func_name, input_types, array_input, return_code, return_message, code)
221249

250+
def _build_pymas(obj, func_name=None, input_types=None, array_input=False,
251+
return_code=True, return_message=True, code=[]):
222252

223-
def _build_pymas(obj, func_name=None, input_types=None, array_input=False, return_code=True, return_message=True, code=[]):
224-
225-
# If the object passed was a function, no need to search for target function
253+
# If the object passed was a function, no need to search for
254+
# target function
226255
if six.callable(obj) and (func_name is None or obj.__name__ == func_name):
227256
target_func = obj
228257
elif func_name is None:
@@ -231,19 +260,23 @@ def _build_pymas(obj, func_name=None, input_types=None, array_input=False, retur
231260
target_func = getattr(obj, func_name)
232261

233262
if not callable(target_func):
234-
raise RuntimeError("Could not find a valid function named {}".format(func_name))
263+
raise RuntimeError("Could not find a valid function named %s"
264+
% func_name)
235265

236266
# Need to create DS2Variable instances to pass to PyMAS
237267
if hasattr(input_types, 'columns'):
238-
# Assuming input is a DataFrame representing model inputs. Use to get input variables
268+
# Assuming input is a DataFrame representing model inputs. Use to
269+
# get input variables
239270
vars = ds2_variables(input_types)
240271

241-
# Run one observation through the model and use the result to determine output variables
272+
# Run one observation through the model and use the result to
273+
# determine output variables
242274
output = target_func(input_types.iloc[0, :].values.reshape((1, -1)))
243275
output_vars = ds2_variables(output, output_vars=True)
244276
vars.extend(output_vars)
245277
elif isinstance(input_types, type):
246-
params = OrderedDict([(k, input_types) for k in target_func.__code__.co_varnames])
278+
params = OrderedDict([(k, input_types)
279+
for k in target_func.__code__.co_varnames])
247280
vars = ds2_variables(params)
248281
elif isinstance(input_types, dict):
249282
vars = ds2_variables(input_types)
@@ -253,22 +286,16 @@ def _build_pymas(obj, func_name=None, input_types=None, array_input=False, retur
253286

254287
target_func = 'obj.' + target_func.__name__
255288

256-
# If all inputs should be passed as an array
257-
if array_input:
258-
first_input = vars[0]
259-
array_type = first_input.type or 'double'
260-
out_vars = [x for x in vars if x.out]
261-
num_inputs = len(vars) - len(out_vars)
262-
vars = [DS2Variable(first_input.name, array_type + '[{}]'.format(num_inputs), out=False)] + out_vars
263-
264289
if not any([v for v in vars if v.out]):
265290
vars.append(DS2Variable(name='result', type='float', out=True))
266291

267-
return PyMAS(target_func, vars, code, return_code, return_message)
292+
return PyMAS(target_func, vars, code, return_code, return_message,
293+
array_input=array_input)
268294

269295

270296
class PyMAS:
271-
def __init__(self, target_function, variables, python_source, return_code=True, return_msg=True):
297+
def __init__(self, target_function, variables, python_source,
298+
return_code=True, return_msg=True, **kwargs):
272299
"""
273300
274301
Parameters
@@ -282,27 +309,30 @@ def __init__(self, target_function, variables, python_source, return_code=True,
282309
Whether the DS2-generated return code should be included
283310
return_msg : bool
284311
Whether the DS2-generated return message should be included
312+
kwargs : any
313+
Passed to :func:`build_wrapper_function`
285314
286315
"""
287316

288317
self.target = target_function
289318

290319
# Any input variable that should be treated as an array
291-
array_input = any(v for v in variables if v.is_array)
320+
# array_input = any(v for v in variables if v.is_array)
292321

293322
# Python wrapper function will serve as entrypoint from DS2
294323
self.wrapper = build_wrapper_function(target_function, variables,
295-
array_input, return_msg=return_msg).split('\n')
324+
setup=python_source,
325+
return_msg=return_msg,
326+
**kwargs).split('\n')
296327

297328
# Lines of Python code to be embedded in DS2
298-
python_source = list(python_source) + list(self.wrapper)
329+
python_source = list(self.wrapper)
299330

300331
self.variables = variables
301332
self.return_code = return_code
302333
self.return_message = return_msg
303334

304-
self.package = DS2Package()
305-
self.package.methods.append(DS2Method(variables, python_source))
335+
self.package = DS2Package(variables, python_source, return_code, return_msg)
306336

307337
def score_code(self, input_table=None, output_table=None, columns=None, dest='MAS'):
308338
"""Generate DS2 score code
@@ -333,7 +363,7 @@ def score_code(self, input_table=None, output_table=None, columns=None, dest='MA
333363
raise ValueError('Output table name `{}` is a reserved term.'.format(output_table))
334364

335365
# Get package code
336-
code = (str(self.package), )
366+
code = tuple(self.package.code().split('\n'))
337367

338368
if dest == 'ESP':
339369
code = ('data sasep.out;', ) + code + (' method run();',

0 commit comments

Comments
 (0)