Skip to content

Commit 40e900a

Browse files
committed
fix: add subset in get_errors calls for process_split and process_imputer
1 parent 1233456 commit 40e900a

File tree

1 file changed

+20
-2
lines changed

1 file changed

+20
-2
lines changed

qolmat/benchmark/comparator.py

Lines changed: 20 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -114,6 +114,13 @@ def process_split(
114114
df_with_holes = df_origin.copy()
115115
df_with_holes[df_mask] = np.nan
116116

117+
subset = self.generator_holes.subset
118+
if subset is None:
119+
raise ValueError(
120+
"HoleGenerator `subset` should be overwritten in split "
121+
"but it is none!"
122+
)
123+
117124
split_results = {}
118125
for imputer_name, imputer in self.dict_imputers.items():
119126
dict_config_opti_imputer = self.dict_config_opti.get(
@@ -131,7 +138,9 @@ def process_split(
131138
)
132139

133140
df_imputed = imputer_opti.fit_transform(df_with_holes)
134-
errors = self.get_errors(df_origin, df_imputed, df_mask)
141+
errors = self.get_errors(
142+
df_origin[subset], df_imputed[subset], df_mask[subset]
143+
)
135144
split_results[imputer_name] = errors
136145

137146
return pd.concat(split_results, axis=1)
@@ -154,6 +163,13 @@ def process_imputer(
154163
"""
155164
imputer_name, imputer, all_masks, df_origin = imputer_data
156165

166+
subset = self.generator_holes.subset
167+
if subset is None:
168+
raise ValueError(
169+
"HoleGenerator `subset` should be overwritten in split "
170+
"but it is none!"
171+
)
172+
157173
dict_config_opti_imputer = self.dict_config_opti.get(imputer_name, {})
158174
imputer_opti = hyperparameters.optimize(
159175
imputer,
@@ -170,7 +186,9 @@ def process_imputer(
170186
df_with_holes = df_origin.copy()
171187
df_with_holes[df_mask] = np.nan
172188
df_imputed = imputer_opti.fit_transform(df_with_holes)
173-
errors = self.get_errors(df_origin, df_imputed, df_mask)
189+
errors = self.get_errors(
190+
df_origin[subset], df_imputed[subset], df_mask[subset]
191+
)
174192
imputer_results.append(errors)
175193

176194
return imputer_name, pd.concat(imputer_results).groupby(

0 commit comments

Comments
 (0)