Skip to content

Commit b7920b6

Browse files
committed
Check all dictionaries for numpy values with the NpEncoder
1 parent 736135f commit b7920b6

File tree

1 file changed

+29
-27
lines changed

1 file changed

+29
-27
lines changed

src/sasctl/pzmm/write_json_files.py

Lines changed: 29 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,28 @@
2323
from ..utils.decorators import deprecated
2424
from ..utils.misc import check_if_jupyter
2525

26+
try:
27+
# noinspection PyPackageRequirements
28+
import numpy as np
29+
30+
31+
class NpEncoder(json.JSONEncoder):
32+
def default(self, obj):
33+
if isinstance(obj, np.integer):
34+
return int(obj)
35+
if isinstance(obj, np.floating):
36+
return float(obj)
37+
if isinstance(obj, np.ndarray):
38+
return obj.tolist()
39+
return json.JSONEncoder.default(self, obj)
40+
41+
except ImportError:
42+
np = None
43+
44+
45+
class NpEncoder(json.JSONEncoder):
46+
pass
47+
2648
# TODO: add converter for any type of dataset (list, dataframe, numpy array)
2749

2850
# Constants
@@ -112,26 +134,6 @@ def write_var_json(
112134
Dictionary containing a key-value pair representing the file name and json
113135
dump respectively.
114136
"""
115-
try:
116-
# noinspection PyPackageRequirements
117-
import numpy as np
118-
119-
class NpEncoder(json.JSONEncoder):
120-
def default(self, obj):
121-
if isinstance(obj, np.integer):
122-
return int(obj)
123-
if isinstance(obj, np.floating):
124-
return float(obj)
125-
if isinstance(obj, np.ndarray):
126-
return obj.tolist()
127-
return json.JSONEncoder.default(self, obj)
128-
129-
except ImportError:
130-
np = None
131-
132-
class NpEncoder(json.JSONEncoder):
133-
pass
134-
135137
# MLFlow model handling
136138
if isinstance(input_data, list):
137139
dict_list = cls.generate_mlflow_variable_properties(input_data)
@@ -597,14 +599,14 @@ def input_fit_statistics(
597599

598600
if json_path:
599601
with open(Path(json_path) / FITSTAT, "w") as json_file:
600-
json_file.write(json.dumps(json_dict, indent=4))
602+
json_file.write(json.dumps(json_dict, indent=4, cls=NpEncoder))
601603
if cls.notebook_output:
602604
print(
603605
f"{FITSTAT} was successfully written and saved to "
604606
f"{Path(json_path) / FITSTAT}"
605607
)
606608
else:
607-
return {FITSTAT: json.dumps(json_dict, indent=4)}
609+
return {FITSTAT: json.dumps(json_dict, indent=4, cls=NpEncoder)}
608610

609611
@classmethod
610612
def add_tuple_to_fitstat(
@@ -881,17 +883,17 @@ def calculate_model_statistics(
881883
if json_path:
882884
for name in [FITSTAT, ROC, LIFT]:
883885
with open(Path(json_path) / name, "w") as json_file:
884-
json_file.write(json.dumps(json_dict, indent=4))
886+
json_file.write(json.dumps(json_dict, indent=4, cls=NpEncoder))
885887
if cls.notebook_output:
886888
print(
887889
f"{name} was successfully written and saved to "
888890
f"{Path(json_path) / name}"
889891
)
890892
else:
891893
return {
892-
FITSTAT: json.dumps(json_dict[0], indent=4),
893-
ROC: json.dumps(json_dict[1], indent=4),
894-
LIFT: json.dumps(json_dict[2], indent=4),
894+
FITSTAT: json.dumps(json_dict[0], indent=4, cls=NpEncoder),
895+
ROC: json.dumps(json_dict[1], indent=4, cls=NpEncoder),
896+
LIFT: json.dumps(json_dict[2], indent=4, cls=NpEncoder),
895897
}
896898

897899
@staticmethod
@@ -1044,7 +1046,7 @@ def apply_dataframe_to_json(
10441046
json_dict[row_num + partition * len(stat_df)]["dataMap"].update(row_dict)
10451047
return json_dict
10461048

1047-
# noinspection PyCallingNonCallable,PyNestedDecorators
1049+
# noinspection PyCallingNonCallable, PyNestedDecorators
10481050
@deprecated(
10491051
"Please use the calculate_model_statistics method instead.",
10501052
version="1.9",

0 commit comments

Comments
 (0)