Skip to content

Commit 3fa1965

Browse files
committed
fix bug on get weight
Signed-off-by: Sunyanan Choochotkaew <[email protected]>
1 parent 10e0742 commit 3fa1965

File tree

1 file changed

+11
-3
lines changed

1 file changed

+11
-3
lines changed

src/train/trainer/scikit.py

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from sklearn.metrics import mean_absolute_error
2-
2+
import numpy as np
33
import os
44
import sys
55

@@ -68,13 +68,21 @@ def get_weight_dict(self, node_type):
6868

6969
for component, model in self.node_models[node_type].items():
7070
scaler = self.node_scalers[node_type]
71-
if not hasattr(model, "intercept_") or not hasattr(model, "coef_") or len(model.coef_) != len(self.features) or len(model.intercept_) != 1:
71+
if not hasattr(model, "intercept_") or not hasattr(model, "coef_") or len(model.coef_) != len(self.features) or (hasattr(model.intercept_, "__len__") and len(model.intercept_) != 1):
7272
return None
7373
else:
74+
if isinstance(model.intercept_, np.float64):
75+
intercept = model.intercept_
76+
elif hasattr(model.intercept_, "__len__"):
77+
intercept = model.intercept_[0]
78+
else:
79+
# no valid intercept
80+
return None
81+
7482
# TODO: remove the mean and variance variables after updating the Kepler code
7583
weight_dict[component] = {
7684
"All_Weights": {
77-
"Bias_Weight": model.intercept_[0],
85+
"Bias_Weight": intercept,
7886
"Categorical_Variables": dict(),
7987
"Numerical_Variables": {self.features[i]:
8088
{"scale": scaler.scale_[i],

0 commit comments

Comments
 (0)