17
17
DATA_PATH = os .path .join (dirname (dirname (__file__ )), 'pyml2ds_data' )
18
18
19
19
20
- def test_xgb2ds (tmpdir ):
20
+ def test_xgb2ds ():
21
21
pytest .importorskip ('xgboost' )
22
22
23
23
IN_PKL = os .path .join (DATA_PATH , 'xgb.pkl' )
24
- OUT_SAS = os .path .join (str (tmpdir ), 'xgb.sas' )
25
24
EXPECTED_SAS = os .path .join (DATA_PATH , 'xgb.sas' )
26
25
27
26
from sasctl .utils .pyml2ds .connectors .ensembles .xgb import XgbTreeParser
@@ -41,18 +40,16 @@ def _leaf_value(self):
41
40
42
41
with mock .patch ('sasctl.utils.pyml2ds.connectors.ensembles.xgb.XgbTreeParser' ) as parser :
43
42
parser .return_value = test_parser
44
- pyml2ds (IN_PKL , OUT_SAS )
43
+ result = pyml2ds (IN_PKL )
45
44
46
- result = open (OUT_SAS , 'rb' ).read ()
47
- expected = open (EXPECTED_SAS , 'rb' ).read ()
45
+ expected = open (EXPECTED_SAS , 'r' ).read ()
48
46
assert result == expected
49
47
50
48
51
- def test_lgb2ds (tmpdir ):
49
+ def test_lgb2ds ():
52
50
pytest .importorskip ('lightgbm' )
53
51
54
52
IN_PKL = os .path .join (DATA_PATH , 'lgb.pkl' )
55
- OUT_SAS = os .path .join (str (tmpdir ), 'lgb.sas' )
56
53
EXPECTED_SAS = os .path .join (DATA_PATH , 'lgb.sas' )
57
54
58
55
from sasctl .utils .pyml2ds .connectors .ensembles .lgb import LightgbmTreeParser
@@ -72,17 +69,14 @@ def _leaf_value(self):
72
69
73
70
with mock .patch ('sasctl.utils.pyml2ds.connectors.ensembles.lgb.LightgbmTreeParser' ) as parser :
74
71
parser .return_value = test_parser
75
- pyml2ds (IN_PKL , OUT_SAS )
72
+ result = pyml2ds (IN_PKL )
76
73
77
-
78
- result = open (OUT_SAS , 'rb' ).read ()
79
- expected = open (EXPECTED_SAS , 'rb' ).read ()
74
+ expected = open (EXPECTED_SAS , 'r' ).read ()
80
75
assert result == expected
81
76
82
77
83
- def test_gbm2ds (tmpdir ):
78
+ def test_gbm2ds ():
84
79
IN_PKL = os .path .join (DATA_PATH , 'gbm.pmml' )
85
- OUT_SAS = os .path .join (str (tmpdir ), 'gbm.sas' )
86
80
EXPECTED_SAS = os .path .join (DATA_PATH , 'gbm.sas' )
87
81
88
82
from sasctl .utils .pyml2ds .connectors .ensembles .pmml import PmmlTreeParser
@@ -102,10 +96,9 @@ def _leaf_value(self):
102
96
103
97
with mock .patch ('sasctl.utils.pyml2ds.connectors.ensembles.pmml.PmmlTreeParser' ) as parser :
104
98
parser .return_value = test_parser
105
- pyml2ds (IN_PKL , OUT_SAS )
99
+ result = pyml2ds (IN_PKL )
106
100
107
- result = open (OUT_SAS , 'rb' ).read ()
108
- expected = open (EXPECTED_SAS , 'rb' ).read ()
101
+ expected = open (EXPECTED_SAS , 'r' ).read ()
109
102
assert result == expected
110
103
111
104
0 commit comments