Skip to content

Commit 9b15266

Browse files
author
Cloud User
committed
Multi fix: call predict with DF vs np array for string issue, strip extra space on strings from SAS char to match python string, generate unique module name for python so MAS does not see conflicts
1 parent 79cbc9f commit 9b15266

File tree

2 files changed

+24
-8
lines changed

2 files changed

+24
-8
lines changed

src/sasctl/utils/pymas/core.py

Lines changed: 20 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -58,11 +58,25 @@ def build_wrapper_function(func, variables, array_input, setup=None,
5858
args = input_names
5959
func = func.__name__ if callable(func) else func
6060

61+
# HELPER: SAS to python char issue where SAS char have spaces and python string does not.
62+
# NOTE: we assume SAS char always need white space to be trimmed. This seems to match python model built so far
63+
pythonStringInput = ('',)
64+
for tmp1 in variables:
65+
if not tmp1.out:
66+
if tmp1.type == 'char':
67+
pythonStringInput = pythonStringInput + (" if " + tmp1.name + ": " + tmp1.name + " = " + tmp1.name + ".strip()",)
68+
6169
# Statement to execute the function w/ provided parameters
6270
if array_input:
63-
func_call = '{}(np.array([{}]).reshape((1, -1)))'.format(func, ','.join(args))
71+
middle = pythonStringInput +\
72+
(' inputarray = np.array([{}]).reshape((1, -1))'.format(','.join(args)),
73+
' column=[{}]'.format(','.join('"{0}"'.format(w) for w in args)),
74+
' import pandas as pd',
75+
' inputrun=pd.DataFrame(data=inputarray, columns=column)',
76+
' result = {}(inputrun)'.format(func))
6477
else:
6578
func_call = '{}({})'.format(func, ','.join(args))
79+
middle = (' result = {}'.format(func_call))
6680

6781
# TODO: Verify that # of values returned by wrapped func matches length of output_names
6882
# TODO: cast all return types before returning (DS2 errors out if not exact match)
@@ -80,6 +94,7 @@ def build_wrapper_function(func, variables, array_input, setup=None,
8094
else:
8195
header = ('', )
8296

97+
8398
definition = header +\
8499
('def wrapper({}):'.format(', '.join(args)),
85100
' "Output: {}"'.format(', '.join(output_names + ['msg']) if return_msg
@@ -90,9 +105,9 @@ def build_wrapper_function(func, variables, array_input, setup=None,
90105
' if _compile_error is not None:',
91106
' raise _compile_error',
92107
' msg = ""' if return_msg else '',
93-
' import numpy as np',
94-
' result = {}'.format(func_call),
95-
' if result.size == 1:',
108+
' import numpy as np') +\
109+
middle +\
110+
(' if result.size == 1:',
96111
' result = np.asscalar(result)',
97112
' except Exception as e:',
98113
' msg = str(e)' if return_msg else '',
@@ -103,6 +118,7 @@ def build_wrapper_function(func, variables, array_input, setup=None,
103118
' else: ',
104119
' return result, msg')
105120

121+
106122
return '\n'.join(definition)
107123

108124

src/sasctl/utils/pymas/ds2.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ def __init__(self, variables, code=None, return_code=True,
1818
self._python_code = code or []
1919
code = code or []
2020

21-
self.methods = [DS2PyMASMethod(variables, code, return_code,
21+
self.methods = [DS2PyMASMethod(self._id, variables, code, return_code,
2222
return_message, target)]
2323

2424
self._body = ("dcl package pymas py;",
@@ -88,7 +88,7 @@ def code(self):
8888

8989

9090
class DS2PyMASMethod(DS2BaseMethod):
91-
def __init__(self, variables, python_code, return_code=True,
91+
def __init__(self, name, variables, python_code, return_code=True,
9292
return_message=True, target='wrapper'):
9393

9494
target = target or 'wrapper'
@@ -110,12 +110,12 @@ def __init__(self, variables, python_code, return_code=True,
110110
self.private_variables] \
111111
+ ["if null(py) then do;",
112112
" py = _new_ pymas();",
113-
" rc = py.useModule('mypymodule', 1);",
113+
" rc = py.useModule('%s', 1);" % name,
114114
" if rc then do;"] \
115115
+ [" rc = py.appendSrcLine('%s');" % l for l in
116116
python_code] \
117117
+ [" pycode = py.getSource();",
118-
" revision = py.publish(pycode, 'mypymodule');",
118+
" revision = py.publish(pycode, '%s');" % name,
119119
" if revision lt 1 then do;",
120120
" logr.log('e', 'py.publish() failed.');",
121121
" rc = -1;",

0 commit comments

Comments
 (0)