
Commit 07d0538

feat: ready to merge
1 parent ddd00de commit 07d0538

File tree

qolmat/benchmark/comparator.py
qolmat/benchmark/missing_patterns.py

2 files changed: +1 -109 lines


qolmat/benchmark/comparator.py

Lines changed: 0 additions & 108 deletions
@@ -165,111 +165,3 @@ def compare(
         df_errors = pd.DataFrame(dict_errors)
 
         return df_errors
-
-
-class ComparatorBasedPattern(Comparator):
-    def __init__(
-        self,
-        dict_models: Dict[str, Any],
-        selected_columns: List[str],
-        generator_holes: _HoleGenerator,
-        metrics: List = ["mae", "wmape", "KL"],
-        search_params: Optional[Dict[str, Dict[str, Union[float, int, str]]]] = {},
-        n_calls_opt: int = 10,
-        num_patterns: int = 5,
-    ):
-        super().__init__(
-            dict_models=dict_models,
-            selected_columns=selected_columns,
-            generator_holes=generator_holes,
-            metrics=metrics,
-            search_params=search_params,
-            n_calls_opt=n_calls_opt,
-        )
-
-        self.num_patterns = num_patterns
-
-    def evaluate_errors_sample(
-        self,
-        imputer: Any,
-        df: pd.DataFrame,
-        list_spaces: List[Dict] = [],
-    ) -> pd.Series:
-        """Evaluate the errors in the cross-validation
-
-        Parameters
-        ----------
-        tested_model : any
-            imputation model
-        df : pd.DataFrame
-            dataframe to impute
-        search_space : Dict
-            search space for tested_model's hyperparameters
-
-        Returns
-        -------
-        pd.DataFrame
-            DataFrame with the errors for each metric (in column) and at each fold (in index)
-        """
-
-        list_errors = []
-        df_origin = df[self.selected_columns].copy()
-        dfs_pattern = self.get_df_based_pattern(df_origin)
-        weights = []
-        # Fit then split, or fit in split ?
-        self.generator_holes.fit(df_origin)
-        for df_pattern in dfs_pattern:
-            # Get all columns in pattern
-            cols_pattern = df_pattern.dropna(axis=1).columns
-            for df_mask in self.generator_holes.split(df_pattern):
-                weights.append(len(df_pattern))
-                df_corrupted = df_pattern.copy()
-                df_corrupted[df_mask] = np.nan
-
-                if list_spaces:
-                    cv = cross_validation.CrossValidation(
-                        imputer,
-                        list_spaces=list_spaces,
-                        hole_generator=self.generator_holes,
-                        n_calls=self.n_calls_opt,
-                    )
-                    df_imputed = cv.fit_transform(df_corrupted)
-                else:
-                    df_imputed = imputer.fit_transform(df_corrupted)
-
-                subset = self.generator_holes.subset  # columns selected
-                subset = [col for col in subset if col in cols_pattern]
-                errors = self.get_errors(df_pattern[subset], df_imputed[subset], df_mask[subset])
-                list_errors.append(errors)
-
-        df_errors = pd.DataFrame(list_errors)
-        # Weighted errors
-        errors_mean = df_errors.apply(
-            lambda x: (x * np.array(weights)).sum() / np.sum(weights), axis=0
-        )
-        return errors_mean.sort_index()
-
-    def get_df_based_pattern(self, df: pd.DataFrame) -> List[pd.DataFrame]:
-        def get_pattern(row):
-            list_col_pattern = [col for col in row.index.to_list() if row[col] == True]
-            if len(list_col_pattern) == 0:
-                return "_EMPTY_"
-            elif len(list_col_pattern) == row.index.size:
-                return "_ALLNAN_"
-            else:
-                return "_".join(list_col_pattern)
-
-        df_isna = df.isna().apply(lambda x: get_pattern(x), axis=1).to_frame(name="pattern")
-        df_isna_pattern = df_isna["pattern"].value_counts()
-
-        patterns = df_isna_pattern.index.to_list()
-        patterns.remove("_ALLNAN_")
-        patterns.remove("_EMPTY_")
-
-        dfs = []
-        for idx_pattern in range(min(len(patterns), self.num_patterns)):
-            patterns_selected = ["_EMPTY_"] + [patterns[idx_pattern]]
-            df_pattern = df.loc[df_isna[df_isna["pattern"].isin(patterns_selected)].index]
-            dfs.append(df_pattern)
-
-        return dfs
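
For context, the removed ComparatorBasedPattern class grouped the rows of a dataframe by their missingness pattern, evaluated each imputer per group, and averaged the per-group errors weighted by group size. Below is a minimal, self-contained sketch of that grouping step only; the toy dataframe and its column names are made up for illustration and are not part of this commit.

import numpy as np
import pandas as pd

# Toy dataframe with a few distinct missingness patterns (hypothetical data).
df = pd.DataFrame(
    {
        "tair": [1.0, np.nan, 3.0, np.nan, 5.0],
        "hum": [0.1, 0.2, np.nan, np.nan, 0.5],
    }
)

def get_pattern(row: pd.Series) -> str:
    # Label a row by the columns that are missing in it, as in the removed code.
    cols_missing = [col for col in row.index if row[col]]
    if not cols_missing:
        return "_EMPTY_"
    if len(cols_missing) == row.index.size:
        return "_ALLNAN_"
    return "_".join(cols_missing)

patterns = df.isna().apply(get_pattern, axis=1)
print(patterns.value_counts())
# Rows sharing a pattern (together with the fully observed "_EMPTY_" rows)
# formed one evaluation group in evaluate_errors_sample.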

qolmat/benchmark/missing_patterns.py

Lines changed: 1 addition & 1 deletion
@@ -111,7 +111,7 @@ def split(self, X: pd.DataFrame) -> List[pd.DataFrame]:
             mask
         """
 
-        # self.fit(X)
+        self.fit(X)
         list_masks = []
         for _ in range(self.n_splits):
             if self.ngroups is None:
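
The missing_patterns.py change makes split responsible for fitting the hole generator, so the explicit fit-then-split dance seen in the removed comparator code (and questioned in its "Fit then split, or fit in split ?" comment) is no longer needed. A minimal sketch of the resulting usage, assuming generator is any _HoleGenerator subclass from qolmat.benchmark.missing_patterns; concrete class names and constructor arguments are not shown because they are not part of this diff.

import numpy as np
import pandas as pd

def corrupt_with_generator(generator, df: pd.DataFrame):
    # After this commit, split() calls self.fit(X) itself, so there is no need
    # to call generator.fit(df) beforehand.
    for df_mask in generator.split(df):
        df_corrupted = df.copy()
        df_corrupted[df_mask] = np.nan  # same masking idiom as the removed comparator code
        yield df_corrupted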

0 commit comments
