Skip to content

Commit ea42699

Browse files
committed
Add handling for MLflow model inputs/outputs
1 parent 9e4d842 commit ea42699

File tree

1 file changed

+61
-34
lines changed

1 file changed

+61
-34
lines changed

src/sasctl/pzmm/writeJSONFiles.py

Lines changed: 61 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515

1616
# %%
1717
class JSONFiles:
18-
def writeVarJSON(self, inputDF, isInput=True, jPath=Path.cwd()):
18+
def writeVarJSON(self, inputData, isInput=True, jPath=Path.cwd()):
1919
"""
2020
Writes a variable descriptor JSON file for input or output variables,
2121
based on an input dataframe containing predictor and prediction columns.
@@ -26,53 +26,80 @@ def writeVarJSON(self, inputDF, isInput=True, jPath=Path.cwd()):
2626
2727
Parameters
2828
---------------
29-
inputDF : Dataframe
29+
inputData : Dataframe or dict
3030
Input dataframe containing the training data set in a
3131
pandas.Dataframe format. Columns are used to define predictor and
32-
prediction variables (ambiguously named "predict").
32+
prediction variables (ambiguously named "predict"). Providing a dict
33+
object signals that the model files are being created from an MLFlow model.
3334
isInput : boolean
3435
Boolean to check if generating the input or output variable JSON.
3536
jPath : string, optional
3637
File location for the output JSON file. Default is the current
3738
working directory.
3839
"""
39-
40-
try:
41-
predictNames = inputDF.columns.values.tolist()
42-
isSeries = False
43-
except AttributeError:
44-
predictNames = [inputDF.name]
45-
isSeries = True
4640
outputJSON = pd.DataFrame()
41+
if isinstance(inputData, dict):
42+
predictNames = [var["name"] for var in inputData]
43+
for i, name in enumerate(predictNames):
44+
if inputData[i]["type"] == "string":
45+
isStr = True
46+
elif inputData[i]["type"] in ["double", "integer", "float", "long"]:
47+
isStr = False
48+
elif inputData[i]["type"] == "tensor":
49+
if inputData[i]["tensor-spec"]["dtype"] in "string":
50+
isStr = True
51+
else:
52+
isStr = False
4753

48-
# loop through all predict variables to determine their name, length,
49-
# type, and level; append each to outputJSON
50-
for name in predictNames:
51-
if isSeries:
52-
predict = inputDF
53-
else:
54-
predict = inputDF[name]
55-
firstRow = predict.loc[predict.first_valid_index()]
56-
dType = predict.dtypes.name
57-
isStr = type(firstRow) is str
58-
59-
if isStr:
60-
outputLevel = "nominal"
61-
outputType = "string"
62-
outputLength = predict.str.len().max()
63-
else:
64-
if dType == "category":
54+
if isStr:
6555
outputLevel = "nominal"
56+
outputType = "string"
57+
outputLength = 8
6658
else:
6759
outputLevel = "interval"
68-
outputType = "decimal"
69-
outputLength = 8
60+
outputType = "decimal"
61+
outputLength = 8
62+
outputRow = pd.Series(
63+
[name, outputLength, outputType, outputLevel],
64+
index=["name", "length", "type", "level"],
65+
)
66+
outputJSON = outputJSON.append([outputRow], ignore_index=True)
67+
else:
68+
try:
69+
predictNames = inputData.columns.values.tolist()
70+
isSeries = False
71+
except AttributeError:
72+
predictNames = [inputData.name]
73+
isSeries = True
74+
75+
# loop through all predict variables to determine their name, length,
76+
# type, and level; append each to outputJSON
77+
for name in predictNames:
78+
if isSeries:
79+
predict = inputData
80+
else:
81+
predict = inputData[name]
82+
firstRow = predict.loc[predict.first_valid_index()]
83+
dType = predict.dtypes.name
84+
isStr = type(firstRow) is str
7085

71-
outputRow = pd.Series(
72-
[name, outputLength, outputType, outputLevel],
73-
index=["name", "length", "type", "level"],
74-
)
75-
outputJSON = outputJSON.append([outputRow], ignore_index=True)
86+
if isStr:
87+
outputLevel = "nominal"
88+
outputType = "string"
89+
outputLength = predict.str.len().max()
90+
else:
91+
if dType == "category":
92+
outputLevel = "nominal"
93+
else:
94+
outputLevel = "interval"
95+
outputType = "decimal"
96+
outputLength = 8
97+
98+
outputRow = pd.Series(
99+
[name, outputLength, outputType, outputLevel],
100+
index=["name", "length", "type", "level"],
101+
)
102+
outputJSON = outputJSON.append([outputRow], ignore_index=True)
76103

77104
if isInput:
78105
fileName = "inputVar.json"

0 commit comments

Comments
 (0)