Skip to content

Commit 5967f77

Browse files
committed
fix: move pattern based metrics from missing_patterns to metrics
1 parent 132d864 commit 5967f77

File tree

3 files changed

+61
-114
lines changed

3 files changed

+61
-114
lines changed

qolmat/benchmark/comparator.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,9 @@ class Comparator:
4141
"pairwise_dist": metrics.sum_pairwise_distances,
4242
"energy": metrics.sum_energy_distances,
4343
"frechet": metrics.frechet_distance,
44-
"correlation_dist": partial(metrics.distance_correlation_complement),
44+
"dist_corr_pattern": partial(
45+
metrics.pattern_based_metric, metric=metrics.distance_correlation_complement
46+
),
4547
}
4648

4749
def __init__(

qolmat/benchmark/metrics.py

Lines changed: 58 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import Callable, List, Optional
1+
from typing import Callable, List, Optional, Dict
22

33
import numpy as np
44
import pandas as pd
@@ -854,10 +854,62 @@ def distance_correlation_complement(
854854
pd.Series
855855
Correlation distance
856856
"""
857-
df1[~df_mask] = np.nan
858-
df2[~df_mask] = np.nan
859-
860-
df1 = df1.dropna(axis=0, how="all").dropna(axis=1, how="any")
861-
df2 = df2.dropna(axis=0, how="all").dropna(axis=1, how="any")
857+
# For the case that we use this function outside pattern_based_metric
858+
df1 = df1[df_mask].fillna(0.0)
859+
df2 = df2[df_mask].fillna(0.0)
862860

863861
return 1.0 - pd.Series(dcor.distance_correlation(df1.values, df2.values), index=["All"])
862+
863+
864+
def pattern_based_metric(
865+
df1: pd.DataFrame,
866+
df2: pd.DataFrame,
867+
df_mask: pd.DataFrame,
868+
metric: Callable,
869+
min_num_row: int = 10,
870+
**kwargs,
871+
) -> pd.Series:
872+
"""Evaluate a metric on each group of rows sharing the same pattern in df_mask and return the mean score.
873+
874+
Parameters
875+
----------
876+
df1 : pd.DataFrame
877+
true dataframe
878+
df2 : pd.DataFrame
879+
predicted dataframe
880+
df_mask : pd.DataFrame
881+
Elements of the dataframes to compute on
882+
metric : Callable
883+
metric function
884+
min_num_row : int, optional
885+
minimum number of rows a pattern must keep (after dropping NaNs) to be scored, by default 10
886+
887+
Returns
888+
-------
889+
pd.Series
890+
Mean of the metric scores over the scored patterns, indexed by "All"
891+
"""
892+
# Identify all distinct missing data patterns
893+
z = 1 + np.log(1 + np.arange(df_mask.shape[1]))
894+
c = np.dot(df_mask, z)
895+
row_map: Dict = {}
896+
for i, v in enumerate(c):
897+
if v == 0:
898+
# No missing values
899+
continue
900+
if v not in row_map:
901+
row_map[v] = []
902+
row_map[v].append(i)
903+
patterns = [np.asarray(v) for v in row_map.values()]
904+
scores = []
905+
for pattern in patterns:
906+
df1_pattern = df1.iloc[pattern].dropna(axis=1)
907+
if len(df1_pattern.columns) == 0:
908+
df1_pattern = df1.iloc[pattern].dropna(axis=0)
909+
910+
if len(df1_pattern) >= min_num_row:
911+
df2_pattern = df2.loc[df1_pattern.index, df1_pattern.columns]
912+
913+
scores.append(metric(df1_pattern, df2_pattern, ~df1_pattern.isna(), **kwargs))
914+
915+
return pd.Series(np.mean(scores), index=["All"])

qolmat/benchmark/missing_patterns.py

Lines changed: 0 additions & 107 deletions
Original file line numberDiff line numberDiff line change
@@ -699,110 +699,3 @@ def split(self, X: pd.DataFrame) -> List[pd.DataFrame]:
699699
list_masks.append(df_mask)
700700

701701
return list_masks
702-
703-
704-
class PatternHoleGenerator(_HoleGenerator):
705-
"""This class implements a way to generate holes in a dataframe.
706-
The holes are generated from pattern, specified by the user.
707-
708-
Parameters
709-
----------
710-
n_splits : int
711-
Number of splits
712-
subset : Optional[List[str]], optional
713-
Names of the columns for which holes must be created, by default None
714-
ratio_masked : Optional[float], optional
715-
Ratio of masked values to add, by default 0.05.
716-
random_state : Optional[int], optional
717-
The seed used by the random number generator, by default None.
718-
groups : List[str]
719-
Names of the columns forming the groups, by default []
720-
"""
721-
722-
def __init__(
723-
self,
724-
n_splits: int,
725-
ratio_masked: float = 0.05,
726-
random_state: Union[None, int, np.random.RandomState] = None,
727-
groups: List[str] = [],
728-
):
729-
super().__init__(
730-
n_splits=n_splits,
731-
subset=None,
732-
random_state=random_state,
733-
ratio_masked=ratio_masked,
734-
groups=groups,
735-
)
736-
737-
def fit(self, X: pd.DataFrame) -> PatternHoleGenerator:
738-
"""Create the groups based on the column names (groups attribute)
739-
740-
Parameters
741-
----------
742-
X : pd.DataFrame
743-
744-
Returns
745-
-------
746-
PatternHoleGenerator
747-
The model itself
748-
749-
Raises
750-
------
751-
if the number of samples/splits is greater than the number of groups.
752-
"""
753-
754-
super().fit(X)
755-
df_isna = X.isna().apply(lambda x: self.get_pattern(x), axis=1).to_frame(name="pattern")
756-
self.df_isna = df_isna["pattern"]
757-
self.patterns_counts = self.df_isna.value_counts()
758-
patterns = self.patterns_counts.index.to_list()
759-
if "_ALLNAN_" in patterns:
760-
patterns.remove("_ALLNAN_")
761-
if "_EMPTY_" in patterns:
762-
patterns.remove("_EMPTY_")
763-
self.patterns = patterns
764-
765-
return self
766-
767-
def generate_mask(self, X: pd.DataFrame) -> pd.DataFrame:
768-
"""
769-
Returns a mask for the dataframe at hand.
770-
771-
Parameters
772-
----------
773-
X : pd.DataFrame
774-
Initial dataframe with a missing pattern to be imitated.
775-
"""
776-
if self.ngroups is not None:
777-
self.fit(X)
778-
779-
df_mask = pd.DataFrame(False, index=X.index, columns=X.columns)
780-
patterns_selected = ["_EMPTY_"]
781-
patterns = self.patterns
782-
for k in range(len(self.patterns)):
783-
pattern = random.choice(patterns)
784-
patterns_selected_ = patterns_selected + [pattern]
785-
patterns.remove(pattern)
786-
df_mask_ = df_mask.copy()
787-
X_ = X.copy()
788-
789-
df_mask_.loc[self.df_isna[self.df_isna.isin(patterns_selected_)].index] = True
790-
X_[~df_mask_] = np.nan
791-
X_ = X_.dropna(axis=0, how="all").dropna(axis=1, how="any")
792-
if X_.size == 0:
793-
break
794-
patterns_selected.append(pattern)
795-
if self.patterns_counts.loc[patterns_selected_].sum() / len(X) > self.ratio_masked:
796-
break
797-
798-
df_mask.loc[self.df_isna[self.df_isna.isin(patterns_selected)].index] = True
799-
return df_mask
800-
801-
def get_pattern(self, row: pd.Series) -> str:
802-
list_col_pattern = [col for col in row.index.to_list() if row[col] == True]
803-
if len(list_col_pattern) == 0:
804-
return "_EMPTY_"
805-
elif len(list_col_pattern) == row.index.size:
806-
return "_ALLNAN_"
807-
else:
808-
return "__".join(list_col_pattern)

0 commit comments

Comments
 (0)