Skip to content

Commit 7cc0855

Browse files
committed
Include float checks on h2o score outputs that are supposed to be numeric in type.
1 parent a40143c commit 7cc0855

File tree

1 file changed

+5
-3
lines changed

1 file changed

+5
-3
lines changed

src/sasctl/pzmm/write_score_code.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -917,7 +917,7 @@ def _no_targets_no_thresholds(
917917
cls.score_code += f"{'':4}{metrics[0]} = prediction[1][0]\n"
918918
for i in range(len(metrics) - 1):
919919
cls.score_code += (
920-
f"{'':4}{metrics[i + 1]} = prediction[1][{i + 1}]\n"
920+
f"{'':4}{metrics[i + 1]} = float(prediction[1][{i + 1}])\n"
921921
)
922922
else:
923923
for i in range(len(metrics)):
@@ -1018,7 +1018,8 @@ def _binary_target(
10181018
"score code should output the classification and probability for "
10191019
"the target event to occur."
10201020
)
1021-
cls.score_code += f"{'':4}return prediction[1][0], prediction[1][2]"
1021+
cls.score_code += f"{'':4}return prediction[1][0], " \
1022+
f"float(prediction[1][2])"
10221023
# Calculate the classification; return the classification and probability
10231024
elif sum(returns) == 0 and len(returns) == 1:
10241025
warn(
@@ -1069,7 +1070,8 @@ def _binary_target(
10691070
elif len(metrics) == 3:
10701071
if h2o_model:
10711072
cls.score_code += (
1072-
f"{'':4}return prediction[1][0], prediction[1][1], prediction[1][2]"
1073+
f"{'':4}return prediction[1][0], float(prediction[1][1]), "
1074+
f"float(prediction[1][2])"
10731075
)
10741076
elif sum(returns) == 0 and len(returns) == 1:
10751077
warn(

0 commit comments

Comments
 (0)