@@ -828,9 +828,13 @@ def test_exceptions(self):
828828 X_train , y_train , X_test , verbose = 25 )
829829
830830 # Internal function model_action
831- assert_raises (ValueError , model_action , LinearRegression (),
832- X_train , y_train , X_test , sample_weight = None ,
831+ assert_raises (ValueError , model_action , LinearRegression (),
832+ X_train , y_train , X_test , sample_weight = None ,
833833 action = 'abc' , transform = None )
834+
835+ # X_test is None when mode != 'oof'
836+ assert_raises (ValueError , stacking , [LinearRegression ()],
837+ X_train , y_train , None , mode = 'oof_pred_bag' )
834838
835839 #---------------------------------------------------------------------------
836840 # Testing parameter warnings
@@ -940,6 +944,80 @@ def test_small_input(self):
940944 assert_array_equal (S_train_1 , S_train_3 )
941945 assert_array_equal (S_test_1 , S_test_3 )
942946
947+ #---------------------------------------------------------------------------
948+ # Mode 'oof', X_test=None
949+ #---------------------------------------------------------------------------
950+
951+ def test_oof_mode_with_none (self ):
952+
953+ model = LinearRegression ()
954+ S_train_1 = cross_val_predict (model , X_train , y = y_train , cv = n_folds ,
955+ n_jobs = 1 , verbose = 0 , method = 'predict' ).reshape (- 1 , 1 )
956+ S_test_1 = None
957+
958+ models = [LinearRegression ()]
959+ S_train_2 , S_test_2 = stacking (models , X_train , y_train , None ,
960+ regression = True , n_folds = n_folds , shuffle = False , save_dir = temp_dir ,
961+ mode = 'oof' , random_state = 0 , verbose = 0 )
962+
963+ # Load OOF from file
964+ # Normally if cleaning is performed there is only one .npy file at given moment
965+ # But if we have no cleaning there may be more then one file so we take the latest
966+ file_name = sorted (glob .glob (os .path .join (temp_dir , '*.npy' )))[- 1 ] # take the latest file
967+ S = np .load (file_name , allow_pickle = True )
968+ S_train_3 = S [0 ]
969+ S_test_3 = S [1 ]
970+
971+ assert_array_equal (S_train_1 , S_train_2 )
972+ assert_array_equal (S_test_1 , S_test_2 )
973+
974+ assert_array_equal (S_train_1 , S_train_3 )
975+ assert_array_equal (S_test_1 , S_test_3 )
976+
977+ #---------------------------------------------------------------------------
978+ # All default values (mode='oof_pred_bag')
979+ #---------------------------------------------------------------------------
980+
981+ def test_all_defaults (self ):
982+
983+ # Override global n_folds=5, because default value in stacking function is 4
984+ n_folds = 4
985+
986+ S_test_temp = np .zeros ((X_test .shape [0 ], n_folds ))
987+ kf = KFold (n_splits = n_folds , shuffle = False , random_state = 0 )
988+ for fold_counter , (tr_index , te_index ) in enumerate (kf .split (X_train , y_train )):
989+ # Split data and target
990+ X_tr = X_train [tr_index ]
991+ y_tr = y_train [tr_index ]
992+ X_te = X_train [te_index ]
993+ y_te = y_train [te_index ]
994+ model = LinearRegression ()
995+ _ = model .fit (X_tr , y_tr )
996+ S_test_temp [:, fold_counter ] = model .predict (X_test )
997+ S_test_1 = np .mean (S_test_temp , axis = 1 ).reshape (- 1 , 1 )
998+
999+ model = LinearRegression ()
1000+ S_train_1 = cross_val_predict (model , X_train , y = y_train , cv = n_folds ,
1001+ n_jobs = 1 , verbose = 0 , method = 'predict' ).reshape (- 1 , 1 )
1002+
1003+ models = [LinearRegression ()]
1004+ S_train_2 , S_test_2 = stacking (models , X_train , y_train , X_test , save_dir = temp_dir )
1005+
1006+ # Load OOF from file
1007+ # Normally if cleaning is performed there is only one .npy file at given moment
1008+ # But if we have no cleaning there may be more then one file so we take the latest
1009+ file_name = sorted (glob .glob (os .path .join (temp_dir , '*.npy' )))[- 1 ] # take the latest file
1010+ S = np .load (file_name , allow_pickle = True )
1011+ S_train_3 = S [0 ]
1012+ S_test_3 = S [1 ]
1013+
1014+ assert_array_equal (S_train_1 , S_train_2 )
1015+ assert_array_equal (S_test_1 , S_test_2 )
1016+
1017+ assert_array_equal (S_train_1 , S_train_3 )
1018+ assert_array_equal (S_test_1 , S_test_3 )
1019+
1020+
9431021#-------------------------------------------------------------------------------
9441022#-------------------------------------------------------------------------------
9451023
0 commit comments