Skip to content

Commit 22d31bb

Browse files
committed
Include json encoder to swap out numpy types for their Python counterparts. #159
1 parent 6e1fa49 commit 22d31bb

File tree

1 file changed

+22
-3
lines changed

1 file changed

+22
-3
lines changed

src/sasctl/pzmm/write_json_files.py

Lines changed: 22 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -112,6 +112,25 @@ def write_var_json(
112112
Dictionary containing a key-value pair representing the file name and json
113113
dump respectively.
114114
"""
115+
try:
116+
# noinspection PyPackageRequirements
117+
import numpy as np
118+
119+
class NpEncoder(json.JSONEncoder):
120+
def default(self, obj):
121+
if isinstance(obj, np.integer):
122+
return int(obj)
123+
if isinstance(obj, np.floating):
124+
return float(obj)
125+
if isinstance(obj, np.ndarray):
126+
return obj.tolist()
127+
return json.JSONEncoder.default(self, obj)
128+
except ImportError:
129+
np = None
130+
131+
class NpEncoder(json.JSONEncoder):
132+
pass
133+
115134
# MLFlow model handling
116135
if isinstance(input_data, list):
117136
dict_list = cls.generate_mlflow_variable_properties(input_data)
@@ -125,17 +144,17 @@ def write_var_json(
125144
file_name = OUTPUT
126145

127146
with open(Path(json_path) / file_name, "w") as json_file:
128-
json_file.write(json.dumps(dict_list, indent=4))
147+
json_file.write(json.dumps(dict_list, indent=4, cls=NpEncoder))
129148
if cls.notebook_output:
130149
print(
131150
f"{file_name} was successfully written and saved to "
132151
f"{Path(json_path) / file_name}"
133152
)
134153
else:
135154
if is_input:
136-
return {INPUT: json.dumps(dict_list)}
155+
return {INPUT: json.dumps(dict_list, indent=4, cls=NpEncoder)}
137156
else:
138-
return {OUTPUT: json.dumps(dict_list)}
157+
return {OUTPUT: json.dumps(dict_list, indent=4, cls=NpEncoder)}
139158

140159
@staticmethod
141160
def generate_variable_properties(

0 commit comments

Comments
 (0)