99from .testing_utils import (if_statsmodels , if_pandas , if_patsy ,
1010 if_environ_has , assert_list_almost_equal_value ,
1111 assert_list_almost_equal ,
12- if_sklearn_version_greater_than_or_equal_to )
12+ if_sklearn_version_greater_than_or_equal_to ,
13+ if_platform_not_win_32 )
1314from nose .tools import (assert_equal , assert_true , assert_almost_equal ,
1415 assert_list_equal , assert_raises , assert_not_equal )
1516import numpy
1920 HingeBasisFunction , LinearBasisFunction )
2021from pyearth import Earth
2122import pyearth
23+ from numpy .testing .utils import assert_array_almost_equal
2224
23- numpy . random . seed ( 0 )
25+ regenerate_target_files = False
2426
27+ numpy .random .seed (1 )
2528basis = Basis (10 )
2629constant = ConstantBasisFunction ()
2730basis .append (constant )
3134basis .append (bf1 )
3235basis .append (bf2 )
3336basis .append (bf3 )
34- X = numpy .random .normal (size = (100 , 10 ))
37+ X = numpy .random .normal (size = (1000 , 10 ))
3538missing = numpy .zeros_like (X , dtype = BOOL )
36- B = numpy .empty (shape = (100 , 4 ), dtype = numpy .float64 )
39+ B = numpy .empty (shape = (1000 , 4 ), dtype = numpy .float64 )
3740basis .transform (X , missing , B )
3841beta = numpy .random .normal (size = 4 )
39- y = numpy .empty (shape = 100 , dtype = numpy .float64 )
40- y [:] = numpy .dot (B , beta ) + numpy .random .normal (size = 100 )
42+ y = numpy .empty (shape = 1000 , dtype = numpy .float64 )
43+ y [:] = numpy .dot (B , beta ) + numpy .random .normal (size = 1000 )
4144default_params = {"penalty" : 1 }
4245
43-
46+ @ if_platform_not_win_32
4447@if_sklearn_version_greater_than_or_equal_to ('0.17.2' )
4548def test_check_estimator ():
49+ numpy .random .seed (0 )
4650 import sklearn .utils .estimator_checks
4751 sklearn .utils .estimator_checks .MULTI_OUTPUT .append ('Earth' )
4852 sklearn .utils .estimator_checks .check_estimator (Earth )
@@ -149,6 +153,7 @@ def test_output_weight():
149153
150154
151155def test_missing_data ():
156+ numpy .random .seed (0 )
152157 earth = Earth (allow_missing = True , ** default_params )
153158 missing_ = numpy .random .binomial (1 , .05 , X .shape ).astype (bool )
154159 X_ = X .copy ()
@@ -157,34 +162,42 @@ def test_missing_data():
157162 res = str (earth .score (X_ , y ))
158163 filename = os .path .join (os .path .dirname (__file__ ),
159164 'earth_regress_missing_data.txt' )
160- # with open(filename, 'w') as fl:
161- # fl.write(res)
165+ if regenerate_target_files :
166+ with open (filename , 'w' ) as fl :
167+ fl .write (res )
162168 with open (filename , 'r' ) as fl :
163169 prev = fl .read ()
164- assert_true (abs (float (res ) - float (prev )) < .03 )
165-
170+ try :
171+ assert_true (abs (float (res ) - float (prev )) < .03 )
172+ except AssertionError :
173+ print ('Got %f, %f' % (float (res ), float (prev )))
174+ raise
166175
167176def test_fit ():
177+ numpy .random .seed (0 )
168178 earth = Earth (** default_params )
169179 earth .fit (X , y )
170180 res = str (earth .rsq_ )
171181 filename = os .path .join (os .path .dirname (__file__ ),
172182 'earth_regress.txt' )
173- # with open(filename, 'w') as fl:
174- # fl.write(res)
183+ if regenerate_target_files :
184+ with open (filename , 'w' ) as fl :
185+ fl .write (res )
175186 with open (filename , 'r' ) as fl :
176187 prev = fl .read ()
177188 assert_true (abs (float (res ) - float (prev )) < .05 )
178189
179190
180191def test_smooth ():
192+ numpy .random .seed (0 )
181193 model = Earth (penalty = 1 , smooth = True )
182194 model .fit (X , y )
183195 res = str (model .rsq_ )
184196 filename = os .path .join (os .path .dirname (__file__ ),
185197 'earth_regress_smooth.txt' )
186- # with open(filename, 'w') as fl:
187- # fl.write(res)
198+ if regenerate_target_files :
199+ with open (filename , 'w' ) as fl :
200+ fl .write (res )
188201 with open (filename , 'r' ) as fl :
189202 prev = fl .read ()
190203 assert_true (abs (float (res ) - float (prev )) < .05 )
@@ -193,11 +206,12 @@ def test_smooth():
193206def test_linvars ():
194207 earth = Earth (** default_params )
195208 earth .fit (X , y , linvars = [0 , 1 , 2 , 3 , 4 , 5 , 6 , 7 , 8 , 9 ])
196- res = str (earth .trace ()) + ' \n ' + earth . summary ( )
209+ res = str (earth .rsq_ )
197210 filename = os .path .join (os .path .dirname (__file__ ),
198211 'earth_linvars_regress.txt' )
199- # with open(filename, 'w') as fl:
200- # fl.write(res)
212+ if regenerate_target_files :
213+ with open (filename , 'w' ) as fl :
214+ fl .write (res )
201215 with open (filename , 'r' ) as fl :
202216 prev = fl .read ()
203217
@@ -301,8 +315,7 @@ def test_pickle_compatibility():
301315 model = earth .fit (X , y )
302316 model_copy = pickle .loads (pickle .dumps (model ))
303317 assert_true (model_copy == model )
304- assert_true (
305- numpy .all (model .predict (X ) == model_copy .predict (X )))
318+ assert_array_almost_equal (model .predict (X ), model_copy .predict (X ))
306319 assert_true (model .basis_ [0 ] is model .basis_ [1 ]._get_root ())
307320 assert_true (model_copy .basis_ [0 ] is model_copy .basis_ [1 ]._get_root ())
308321
@@ -318,11 +331,11 @@ def test_pickle_version_storage():
318331
319332
320333def test_copy_compatibility ():
334+ numpy .random .seed (0 )
321335 model = Earth (** default_params ).fit (X , y )
322336 model_copy = copy .copy (model )
323337 assert_true (model_copy == model )
324- assert_true (
325- numpy .all (model .predict (X ) == model_copy .predict (X )))
338+ assert_array_almost_equal (model .predict (X ), model_copy .predict (X ))
326339 assert_true (model .basis_ [0 ] is model .basis_ [1 ]._get_root ())
327340 assert_true (model_copy .basis_ [0 ] is model_copy .basis_ [1 ]._get_root ())
328341
0 commit comments