6
6
7
7
8
8
import os
9
- dirname = os .path .dirname
10
9
11
10
import pytest
11
+ from six .moves import mock
12
+
12
13
from sasctl .utils .pyml2ds import pyml2ds
13
14
14
15
16
+ dirname = os .path .dirname
15
17
DATA_PATH = os .path .join (dirname (dirname (__file__ )), 'pyml2ds_data' )
16
18
17
19
@@ -22,7 +24,25 @@ def test_xgb2ds(tmpdir):
22
24
OUT_SAS = os .path .join (str (tmpdir ), 'xgb.sas' )
23
25
EXPECTED_SAS = os .path .join (DATA_PATH , 'xgb.sas' )
24
26
25
- pyml2ds (IN_PKL , OUT_SAS , test = True )
27
+ from sasctl .utils .pyml2ds .connectors .ensembles .xgb import XgbTreeParser
28
+
29
+ # Expected output contains integer values instead of floats.
30
+ # Convert to ensure match.
31
+ class TestXgbTreeParser (XgbTreeParser ):
32
+ def _split_value (self ):
33
+ val = super (TestXgbTreeParser , self )._split_value ()
34
+ return int (float (val ))
35
+
36
+ def _leaf_value (self ):
37
+ val = super (TestXgbTreeParser , self )._leaf_value ()
38
+ return int (float (val ))
39
+
40
+ test_parser = TestXgbTreeParser ()
41
+
42
+ with mock .patch ('sasctl.utils.pyml2ds.connectors.ensembles.xgb.XgbTreeParser' ) as parser :
43
+ parser .return_value = test_parser
44
+ pyml2ds (IN_PKL , OUT_SAS )
45
+
26
46
result = open (OUT_SAS , 'rb' ).read ()
27
47
expected = open (EXPECTED_SAS , 'rb' ).read ()
28
48
assert result == expected
@@ -35,7 +55,26 @@ def test_lgb2ds(tmpdir):
35
55
OUT_SAS = os .path .join (str (tmpdir ), 'lgb.sas' )
36
56
EXPECTED_SAS = os .path .join (DATA_PATH , 'lgb.sas' )
37
57
38
- pyml2ds (IN_PKL , OUT_SAS , test = True )
58
+ from sasctl .utils .pyml2ds .connectors .ensembles .lgb import LightgbmTreeParser
59
+
60
+ # Expected output contains integer values instead of floats.
61
+ # Convert to ensure match.
62
+ class TestLightgbmTreeParser (LightgbmTreeParser ):
63
+ def _split_value (self ):
64
+ val = super (TestLightgbmTreeParser , self )._split_value ()
65
+ return int (float (val ))
66
+
67
+ def _leaf_value (self ):
68
+ val = super (TestLightgbmTreeParser , self )._leaf_value ()
69
+ return int (float (val ))
70
+
71
+ test_parser = TestLightgbmTreeParser ()
72
+
73
+ with mock .patch ('sasctl.utils.pyml2ds.connectors.ensembles.lgb.LightgbmTreeParser' ) as parser :
74
+ parser .return_value = test_parser
75
+ pyml2ds (IN_PKL , OUT_SAS )
76
+
77
+
39
78
result = open (OUT_SAS , 'rb' ).read ()
40
79
expected = open (EXPECTED_SAS , 'rb' ).read ()
41
80
assert result == expected
@@ -46,7 +85,27 @@ def test_gbm2ds(tmpdir):
46
85
OUT_SAS = os .path .join (str (tmpdir ), 'gbm.sas' )
47
86
EXPECTED_SAS = os .path .join (DATA_PATH , 'gbm.sas' )
48
87
49
- pyml2ds (IN_PKL , OUT_SAS , test = True )
88
+ from sasctl .utils .pyml2ds .connectors .ensembles .pmml import PmmlTreeParser
89
+
90
+ # Expected output contains integer values instead of floats.
91
+ # Convert to ensure match.
92
+ class TestPmmlTreeParser (PmmlTreeParser ):
93
+ def _split_value (self ):
94
+ val = super (TestPmmlTreeParser , self )._split_value ()
95
+ return int (float (val ))
96
+
97
+ def _leaf_value (self ):
98
+ val = super (TestPmmlTreeParser , self )._leaf_value ()
99
+ return int (float (val ))
100
+
101
+ test_parser = TestPmmlTreeParser ()
102
+
103
+ with mock .patch ('sasctl.utils.pyml2ds.connectors.ensembles.pmml.PmmlTreeParser' ) as parser :
104
+ parser .return_value = test_parser
105
+ pyml2ds (IN_PKL , OUT_SAS )
106
+
50
107
result = open (OUT_SAS , 'rb' ).read ()
51
108
expected = open (EXPECTED_SAS , 'rb' ).read ()
52
109
assert result == expected
110
+
111
+
0 commit comments