Skip to content

Commit 132d864

Browse files
committed
fix: PatternHoleGenerator
1 parent f72d385 commit 132d864

File tree

2 files changed

+52
-31
lines changed

2 files changed

+52
-31
lines changed

qolmat/benchmark/metrics.py

Lines changed: 28 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -384,32 +384,6 @@ def kolmogorov_smirnov_test(
384384
)
385385

386386

387-
def distance_correlation_complement(
388-
df1: pd.DataFrame, df2: pd.DataFrame, df_mask: pd.DataFrame
389-
) -> pd.Series:
390-
"""Correlation distance between columns of 2 dataframes.
391-
392-
Parameters
393-
----------
394-
df1 : pd.DataFrame
395-
true dataframe
396-
df2 : pd.DataFrame
397-
predicted dataframe
398-
df_mask : pd.DataFrame
399-
Elements of the dataframes to compute on
400-
401-
Returns
402-
-------
403-
pd.Series
404-
Correlation distance
405-
"""
406-
407-
df1 = df1.fillna(0.0)
408-
df2 = df2[["TEMP", "PRES"]].fillna(0.0)
409-
print(df1.shape, df2.shape)
410-
return 1.0 - pd.Series(dcor.distance_correlation(df1.values, df2.values), index=["All"])
411-
412-
413387
def _total_variance_distance_1D(df1: pd.Series, df2: pd.Series) -> float:
414388
"""Compute Total Variance Distance for a categorical feature
415389
It is based on TVComplement in https://github.com/sdv-dev/SDMetrics
@@ -859,3 +833,31 @@ def frechet_distance(
859833
return pd.Series((frechet_dist / df_true.shape[0]), index=["All"])
860834
else:
861835
return pd.Series(np.repeat(frechet_dist, len(df1.columns)))
836+
837+
838+
def distance_correlation_complement(
839+
df1: pd.DataFrame, df2: pd.DataFrame, df_mask: pd.DataFrame
840+
) -> pd.Series:
841+
"""Correlation distance between columns of 2 dataframes.
842+
843+
Parameters
844+
----------
845+
df1 : pd.DataFrame
846+
true dataframe
847+
df2 : pd.DataFrame
848+
predicted dataframe
849+
df_mask : pd.DataFrame
850+
Elements of the dataframes to compute on
851+
852+
Returns
853+
-------
854+
pd.Series
855+
Correlation distance
856+
"""
857+
df1[~df_mask] = np.nan
858+
df2[~df_mask] = np.nan
859+
860+
df1 = df1.dropna(axis=0, how="all").dropna(axis=1, how="any")
861+
df2 = df2.dropna(axis=0, how="all").dropna(axis=1, how="any")
862+
863+
return 1.0 - pd.Series(dcor.distance_correlation(df1.values, df2.values), index=["All"])

qolmat/benchmark/missing_patterns.py

Lines changed: 24 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import logging
55
from typing import Callable, List, Optional, Tuple, Union
66

7+
import random
78
import numpy as np
89
import pandas as pd
910
from sklearn.model_selection import GroupShuffleSplit
@@ -753,8 +754,8 @@ def fit(self, X: pd.DataFrame) -> PatternHoleGenerator:
753754
super().fit(X)
754755
df_isna = X.isna().apply(lambda x: self.get_pattern(x), axis=1).to_frame(name="pattern")
755756
self.df_isna = df_isna["pattern"]
756-
757-
patterns = self.df_isna.value_counts().index.to_list()
757+
self.patterns_counts = self.df_isna.value_counts()
758+
patterns = self.patterns_counts.index.to_list()
758759
if "_ALLNAN_" in patterns:
759760
patterns.remove("_ALLNAN_")
760761
if "_EMPTY_" in patterns:
@@ -772,11 +773,29 @@ def generate_mask(self, X: pd.DataFrame) -> pd.DataFrame:
772773
X : pd.DataFrame
773774
Initial dataframe with a missing pattern to be imitated.
774775
"""
776+
if self.ngroups is not None:
777+
self.fit(X)
778+
775779
df_mask = pd.DataFrame(False, index=X.index, columns=X.columns)
776-
patterns_selected = ["_EMPTY_"] + [self.patterns[0]]
777-
df_mask.loc[self.df_isna[self.df_isna.isin(patterns_selected)].index] = True
778-
df_mask[X.isna()] = False
780+
patterns_selected = ["_EMPTY_"]
781+
patterns = self.patterns
782+
for k in range(len(self.patterns)):
783+
pattern = random.choice(patterns)
784+
patterns_selected_ = patterns_selected + [pattern]
785+
patterns.remove(pattern)
786+
df_mask_ = df_mask.copy()
787+
X_ = X.copy()
788+
789+
df_mask_.loc[self.df_isna[self.df_isna.isin(patterns_selected_)].index] = True
790+
X_[~df_mask_] = np.nan
791+
X_ = X_.dropna(axis=0, how="all").dropna(axis=1, how="any")
792+
if X_.size == 0:
793+
break
794+
patterns_selected.append(pattern)
795+
if self.patterns_counts.loc[patterns_selected_].sum() / len(X) > self.ratio_masked:
796+
break
779797

798+
df_mask.loc[self.df_isna[self.df_isna.isin(patterns_selected)].index] = True
780799
return df_mask
781800

782801
def get_pattern(self, row: pd.Series) -> str:

0 commit comments

Comments
 (0)