@@ -518,3 +518,24 @@ def test_method_error_in_update(monkeypatch: Any, method: str) -> None:
518518 with pytest .raises (ValueError , match = r".*Invalid method.*" ):
519519 mapie_ts_reg .fit (X_toy , y_toy )
520520 mapie_ts_reg .update (X_toy , y_toy )
521+
522+
523+ @pytest .mark .parametrize ("method" , ["enbpi" , "aci" ])
524+ @pytest .mark .parametrize ("cv" , ["split" , "prefit" ])
525+ def test_methods_preservation_in_fit (method : str , cv : str ) -> None :
526+ """Test of enbpi and aci method preservation in the fit MapieRegressor"""
527+
528+ X_train_val , X_test , y_train_val , y_test = train_test_split (
529+ X , y , test_size = 0.33 , random_state = random_state
530+ )
531+ X_train , X_val , y_train , y_val = train_test_split (
532+ X_train_val , y_train_val , test_size = 0.5 , random_state = random_state
533+ )
534+ estimator = LinearRegression ().fit (X_train , y_train )
535+ mapie_ts_reg = MapieTimeSeriesRegressor (
536+ estimator = estimator ,
537+ cv = cv , method = method
538+ )
539+ mapie_ts_reg .fit (X_val , y_val )
540+ mapie_ts_reg .update (X_test , y_test , gamma = 0.1 , alpha = 0.1 )
541+ assert mapie_ts_reg .method == method
0 commit comments