Skip to content

Commit 159767e

Browse files
author
Jon Walker
authored
Merge pull request #41 from jameskochubasas/master
python models with string modifications
2 parents f8d54d8 + ec358d5 commit 159767e

File tree

4 files changed

+35
-14
lines changed

4 files changed

+35
-14
lines changed

src/sasctl/utils/pymas/core.py

Lines changed: 20 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -58,11 +58,25 @@ def build_wrapper_function(func, variables, array_input, setup=None,
5858
args = input_names
5959
func = func.__name__ if callable(func) else func
6060

61+
# HELPER: SAS to python char issue where SAS char have spaces and python string does not.
62+
# NOTE: we assume SAS char always need white space to be trimmed. This seems to match python model built so far
63+
pythonStringInput = ('',)
64+
for tmp1 in variables:
65+
if not tmp1.out:
66+
if tmp1.type == 'char':
67+
pythonStringInput = pythonStringInput + (" if " + tmp1.name + ": " + tmp1.name + " = " + tmp1.name + ".strip()",)
68+
6169
# Statement to execute the function w/ provided parameters
6270
if array_input:
63-
func_call = '{}(np.array([{}]).reshape((1, -1)))'.format(func, ','.join(args))
71+
middle = pythonStringInput +\
72+
(' inputarray = np.array([{}]).reshape((1, -1))'.format(','.join(args)),
73+
' column=[{}]'.format(','.join('"{0}"'.format(w) for w in args)),
74+
' import pandas as pd',
75+
' inputrun=pd.DataFrame(data=inputarray, columns=column)',
76+
' result = {}(inputrun)'.format(func))
6477
else:
6578
func_call = '{}({})'.format(func, ','.join(args))
79+
middle = (' result = {}'.format(func_call),)
6680

6781
# TODO: Verify that # of values returned by wrapped func matches length of output_names
6882
# TODO: cast all return types before returning (DS2 errors out if not exact match)
@@ -80,6 +94,7 @@ def build_wrapper_function(func, variables, array_input, setup=None,
8094
else:
8195
header = ('', )
8296

97+
8398
definition = header +\
8499
('def wrapper({}):'.format(', '.join(args)),
85100
' "Output: {}"'.format(', '.join(output_names + ['msg']) if return_msg
@@ -90,9 +105,9 @@ def build_wrapper_function(func, variables, array_input, setup=None,
90105
' if _compile_error is not None:',
91106
' raise _compile_error',
92107
' msg = ""' if return_msg else '',
93-
' import numpy as np',
94-
' result = {}'.format(func_call),
95-
' if result.size == 1:',
108+
' import numpy as np') +\
109+
middle +\
110+
(' if result.size == 1:',
96111
' result = np.asscalar(result)',
97112
' except Exception as e:',
98113
' msg = str(e)' if return_msg else '',
@@ -103,6 +118,7 @@ def build_wrapper_function(func, variables, array_input, setup=None,
103118
' else: ',
104119
' return result, msg')
105120

121+
106122
return '\n'.join(definition)
107123

108124

src/sasctl/utils/pymas/ds2.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ def __init__(self, variables, code=None, return_code=True,
1818
self._python_code = code or []
1919
code = code or []
2020

21-
self.methods = [DS2PyMASMethod(variables, code, return_code,
21+
self.methods = [DS2PyMASMethod(self._id, variables, code, return_code,
2222
return_message, target)]
2323

2424
self._body = ("dcl package pymas py;",
@@ -88,7 +88,7 @@ def code(self):
8888

8989

9090
class DS2PyMASMethod(DS2BaseMethod):
91-
def __init__(self, variables, python_code, return_code=True,
91+
def __init__(self, name, variables, python_code, return_code=True,
9292
return_message=True, target='wrapper'):
9393

9494
target = target or 'wrapper'
@@ -110,12 +110,12 @@ def __init__(self, variables, python_code, return_code=True,
110110
self.private_variables] \
111111
+ ["if null(py) then do;",
112112
" py = _new_ pymas();",
113-
" rc = py.useModule('mypymodule', 1);",
113+
" rc = py.useModule('%s', 1);" % name,
114114
" if rc then do;"] \
115115
+ [" rc = py.appendSrcLine('%s');" % l for l in
116116
python_code] \
117117
+ [" pycode = py.getSource();",
118-
" revision = py.publish(pycode, 'mypymodule');",
118+
" revision = py.publish(pycode, '%s');" % name,
119119
" if revision lt 1 then do;",
120120
" logr.log('e', 'py.publish() failed.');",
121121
" rc = -1;",

tests/integration/test_pymas.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -165,7 +165,7 @@ def test_from_pickle(train_data, pickle_file):
165165
166166
if null(py) then do;
167167
py = _new_ pymas();
168-
rc = py.useModule('mypymodule', 1);
168+
rc = py.useModule('DF74A4B18C9E41A2A34B0053E123AA67', 1);
169169
if rc then do;
170170
rc = py.appendSrcLine('try:');
171171
rc = py.appendSrcLine(' import pickle, base64');
@@ -184,7 +184,12 @@ def test_from_pickle(train_data, pickle_file):
184184
rc = py.appendSrcLine(' raise _compile_error');
185185
rc = py.appendSrcLine(' msg = ""');
186186
rc = py.appendSrcLine(' import numpy as np');
187-
rc = py.appendSrcLine(' result = obj.predict(np.array([SepalLength,SepalWidth,PetalLength,PetalWidth]).reshape((1, -1)))');
187+
rc = py.appendSrcLine('');
188+
rc = py.appendSrcLine(' inputarray = np.array([SepalLength,SepalWidth,PetalLength,PetalWidth]).reshape((1, -1))');
189+
rc = py.appendSrcLine(' column=["SepalLength","SepalWidth","PetalLength","PetalWidth"]');
190+
rc = py.appendSrcLine(' import pandas as pd');
191+
rc = py.appendSrcLine(' inputrun=pd.DataFrame(data=inputarray, columns=column)');
192+
rc = py.appendSrcLine(' result = obj.predict(inputrun)');
188193
rc = py.appendSrcLine(' if result.size == 1:');
189194
rc = py.appendSrcLine(' result = np.asscalar(result)');
190195
rc = py.appendSrcLine(' except Exception as e:');
@@ -196,7 +201,7 @@ def test_from_pickle(train_data, pickle_file):
196201
rc = py.appendSrcLine(' else: ');
197202
rc = py.appendSrcLine(' return result, msg');
198203
pycode = py.getSource();
199-
revision = py.publish(pycode, 'mypymodule');
204+
revision = py.publish(pycode, 'DF74A4B18C9E41A2A34B0053E123AA67');
200205
if revision lt 1 then do;
201206
logr.log('e', 'py.publish() failed.');
202207
rc = -1;

tests/unit/test_pymas.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -357,7 +357,7 @@ def domath(a, b):
357357
d = py.getDouble('d');
358358
end;
359359
"""
360-
method = DS2PyMASMethod([
360+
method = DS2PyMASMethod('mypymodule',[
361361
DS2Variable('a', 'double', False),
362362
DS2Variable('b', 'double', False),
363363
DS2Variable('c', 'double', True),
@@ -452,15 +452,15 @@ def test_ds2_package():
452452
453453
if null(py) then do;
454454
py = _new_ pymas();
455-
rc = py.useModule('mypymodule', 1);
455+
rc = py.useModule('pyscore', 1);
456456
if rc then do;
457457
rc = py.appendSrcLine('def domath(a, b):');
458458
rc = py.appendSrcLine(' "Output: c, d"');
459459
rc = py.appendSrcLine(' c = a * b');
460460
rc = py.appendSrcLine(' d = a / b');
461461
rc = py.appendSrcLine(' return c, d');
462462
pycode = py.getSource();
463-
revision = py.publish(pycode, 'mypymodule');
463+
revision = py.publish(pycode, 'pyscore');
464464
if revision lt 1 then do;
465465
logr.log('e', 'py.publish() failed.');
466466
rc = -1;

0 commit comments

Comments
 (0)