Skip to content

Commit c85a0cf

Browse files
InvalidPointerInvalidPointer
authored andcommitted
Refactored, added tests
1 parent dd9bf6c commit c85a0cf

File tree

14 files changed

+33944
-102
lines changed

14 files changed

+33944
-102
lines changed
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
from .ensembles import XgbParser, LightgbmParser, PmmlParser, ForestParser
1+
from .ensembles import XgbParser, LightgbmParser, PmmlParser
Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
11
from .xgb import XgbParser
22
from .lgb import LightgbmParser
3-
from .pmml import PmmlParser
4-
from .forest import ForestParser
3+
from .pmml import PmmlParser

src/sasctl/utils/pyml2ds/connectors/ensembles/forest.py

Lines changed: 0 additions & 89 deletions
This file was deleted.

src/sasctl/utils/pyml2ds/connectors/ensembles/lgb.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,7 @@ def __init__(self, booster):
6767
self.out_transform = "1 / (1 + exp(-{0}))"
6868

6969
self._tree_parser = LightgbmTreeParser()
70+
self._tree_parser._features = self._features
7071

7172

7273
def _iter_trees(self):

src/sasctl/utils/pyml2ds/connectors/ensembles/pmml.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
except ImportError:
44
objectify = None
55

6-
from pyml2ds.basic.tree import TreeParser
6+
from pyml2ds.basic import TreeParser
77
from .core import EnsembleParser
88

99

src/sasctl/utils/pyml2ds/connectors/ensembles/xgb.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import json
2-
from pyml2ds.basic.tree import TreeParser
2+
from pyml2ds.basic import TreeParser
33
from .core import EnsembleParser
44

55

src/sasctl/utils/pyml2ds/pyml2ds.py

Lines changed: 1 addition & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -20,16 +20,11 @@
2020
except ImportError:
2121
lightgbm = None
2222

23-
try:
24-
from sklearn.ensemble import RandomForestClassifier
25-
except ImportError:
26-
RandomForestClassifier = None
27-
2823
from .connectors import XgbParser, LightgbmParser, PmmlParser, ForestParser
2924

3025

3126
def _check_type(model):
32-
comp_types = ["xgboost.sklearn.XGBModel", "lightgbm.LGBMModel", "lightgbm.basic.Booster", "sklearn.ensemble.RandomForestClassifier", "GBM.pmml file"]
27+
comp_types = ["xgboost.sklearn.XGBModel", "lightgbm.LGBMModel", "lightgbm.basic.Booster", "GBM.pmml file"]
3328

3429
parser = None
3530
if xgboost and isinstance(model, xgboost.sklearn.XGBModel):
@@ -43,8 +38,6 @@ def _check_type(model):
4338
parser = LightgbmParser(model)
4439
elif etree and isinstance(model, etree._ElementTree):
4540
parser = PmmlParser(model.getroot())
46-
elif RandomForestClassifier and isinstance(model, RandomForestClassifier):
47-
parser = ForestParser(model)
4841
else:
4942
raise RuntimeError("Unknown booster type: %s. Compatible types are: %s. Check if corresponding library is installed." % type(model).__name__)
5043

tests/pyml2ds_data/gbm.pmml

Lines changed: 18178 additions & 0 deletions
Large diffs are not rendered by default.

0 commit comments

Comments
 (0)