Skip to content

Commit 512cf21

Browse files
committed
better handling of input types
1 parent 0a191ca commit 512cf21

File tree

4 files changed

+37
-7
lines changed

4 files changed

+37
-7
lines changed

src/sasctl/_services/microanalytic_score.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,8 @@
99
import re
1010
from collections import OrderedDict
1111

12+
import six
13+
1214
from .service import Service
1315

1416

@@ -116,7 +118,18 @@ def execute_module_step(self, module, step, return_dict=True, **kwargs):
116118
module = module.id
117119
step = step.id if hasattr(step, 'id') else step
118120

119-
body = {'inputs': [{'name': k, 'value': v} for k, v in kwargs.items()]}
121+
# Make sure all inputs are JSON serializable
122+
# Common types such as numpy.int64 and numpy.float64 are NOT serializable
123+
for k in kwargs.keys():
124+
type_name = type(kwargs[k]).__name__
125+
if type_name == 'float64':
126+
kwargs[k] = float(kwargs[k])
127+
elif type_name == 'int64':
128+
kwargs[k] = int(kwargs[k])
129+
130+
131+
body = {'inputs': [{'name': k, 'value': v}
132+
for k, v in six.iteritems(kwargs)]}
120133
r = self.post('/modules/{}/steps/{}'.format(module, step), json=body)
121134

122135
# Convert list of name/value pair dictionaries to single dict

src/sasctl/tasks.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -88,7 +88,13 @@ def register_model(model, name, project, repository=None, input=None,
8888
repository : str or dict, optional
8989
The name or id of the repository, or a dictionary representation of
9090
the repository. If omitted, the default repository will be used.
91-
input
91+
input : DataFrame, type, list of type, or dict of str: type, optional
92+
The expected type for each input value of the target function.
93+
Can be omitted if target function includes type hints. If a DataFrame
94+
is provided, the columns will be inspected to determine type information.
95+
If a single type is provided, all columns will be assumed to be that type,
96+
otherwise a list of column types or a dictionary of column_name: type
97+
may be provided.
9298
version : {'new', 'latest', int}, optional
9399
Version number of the project in which the model should be created.
94100
Defaults to 'new'.

src/sasctl/utils/pymas/core.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -200,9 +200,13 @@ def from_pickle(file, func_name=None, input_types=None, array_input=False,
200200
object, and bytes is assumed to be the raw pickled bytes.
201201
func_name : str
202202
Name of the target function to call
203-
input_types : list of type, optional
203+
input_types : DataFrame, type, list of type, or dict of str: type, optional
204204
The expected type for each input value of the target function.
205-
Can be ommitted if target function includes type hints.
205+
Can be omitted if target function includes type hints. If a DataFrame
206+
is provided, the columns will be inspected to determine type information.
207+
If a single type is provided, all columns will be assumed to be that type,
208+
otherwise a list of column types or a dictionary of column_name: type
209+
may be provided.
206210
array_input : bool
207211
Whether the function inputs should be treated as an array instead of
208212
individual parameters
@@ -271,7 +275,8 @@ def _build_pymas(obj, func_name=None, input_types=None, array_input=False,
271275

272276
# Run one observation through the model and use the result to
273277
# determine output variables
274-
output = target_func(input_types.iloc[0, :].values.reshape((1, -1)))
278+
output = target_func(input_types.head(1))
279+
# output = target_func(input_types.iloc[0, :].values.reshape((1, -1)))
275280
output_vars = ds2_variables(output, output_vars=True)
276281
vars.extend(output_vars)
277282
elif isinstance(input_types, type):

src/sasctl/utils/pymas/python.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -51,8 +51,14 @@ def ds2_variables(input, output_vars=False):
5151
types = input
5252
elif hasattr(input, 'columns') and hasattr(input, 'dtypes'):
5353
# Pandas DataFrame
54-
types = OrderedDict([(col, (input[col].dtype.name.replace('object', 'char'), False)) for col in input.columns])
55-
# types = {col: (input[col].dtype.name.replace('object', 'char'), False) for col in input.columns}
54+
types = OrderedDict()
55+
for col in input.columns:
56+
if input[col].dtype.name == 'object':
57+
types[col] = ('char', False)
58+
elif input[col].dtype.name == 'category':
59+
types[col] = ('char', False)
60+
else:
61+
types[col] = (input[col].dtype.name, False)
5662
elif hasattr(input, 'dtype'):
5763
# Numpy array? No column names, but we can at least create dummy vars of the correct type
5864
types = OrderedDict([('var{}'.format(i),

0 commit comments

Comments
 (0)