@@ -114,6 +114,13 @@ def process_split(
114114 df_with_holes = df_origin .copy ()
115115 df_with_holes [df_mask ] = np .nan
116116
117+ subset = self .generator_holes .subset
118+ if subset is None :
119+ raise ValueError (
120+ "HoleGenerator `subset` should be overwritten in split "
121+ "but it is none!"
122+ )
123+
117124 split_results = {}
118125 for imputer_name , imputer in self .dict_imputers .items ():
119126 dict_config_opti_imputer = self .dict_config_opti .get (
@@ -131,7 +138,9 @@ def process_split(
131138 )
132139
133140 df_imputed = imputer_opti .fit_transform (df_with_holes )
134- errors = self .get_errors (df_origin , df_imputed , df_mask )
141+ errors = self .get_errors (
142+ df_origin [subset ], df_imputed [subset ], df_mask [subset ]
143+ )
135144 split_results [imputer_name ] = errors
136145
137146 return pd .concat (split_results , axis = 1 )
@@ -154,6 +163,13 @@ def process_imputer(
154163 """
155164 imputer_name , imputer , all_masks , df_origin = imputer_data
156165
166+ subset = self .generator_holes .subset
167+ if subset is None :
168+ raise ValueError (
169+ "HoleGenerator `subset` should be overwritten in split "
170+ "but it is none!"
171+ )
172+
157173 dict_config_opti_imputer = self .dict_config_opti .get (imputer_name , {})
158174 imputer_opti = hyperparameters .optimize (
159175 imputer ,
@@ -170,7 +186,9 @@ def process_imputer(
170186 df_with_holes = df_origin .copy ()
171187 df_with_holes [df_mask ] = np .nan
172188 df_imputed = imputer_opti .fit_transform (df_with_holes )
173- errors = self .get_errors (df_origin , df_imputed , df_mask )
189+ errors = self .get_errors (
190+ df_origin [subset ], df_imputed [subset ], df_mask [subset ]
191+ )
174192 imputer_results .append (errors )
175193
176194 return imputer_name , pd .concat (imputer_results ).groupby (
0 commit comments