Skip to content

Commit 2b74646

Browse files
Merge pull request #51 from Quantmetry/angoho_benchmarks
Feat: add distance correlation and pattern-based metric
2 parents ea44694 + 4f0218a commit 2b74646

File tree

2 files changed

+126
-0
lines changed

2 files changed

+126
-0
lines changed

qolmat/benchmark/metrics.py

Lines changed: 97 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from sklearn import metrics as skm
99
from sklearn.ensemble import BaseEnsemble
1010
from sklearn.preprocessing import StandardScaler
11+
import dcor
1112

1213
EPS = np.finfo(float).eps
1314

@@ -835,6 +836,98 @@ def frechet_distance(
835836
return pd.Series(np.repeat(frechet_dist, len(df1.columns)))
836837

837838

839+
def distance_correlation_complement(
840+
df1: pd.DataFrame, df2: pd.DataFrame, df_mask: pd.DataFrame
841+
) -> pd.Series:
842+
"""Correlation distance between columns of 2 dataframes.
843+
844+
Parameters
845+
----------
846+
df1 : pd.DataFrame
847+
true dataframe
848+
df2 : pd.DataFrame
849+
predicted dataframe
850+
df_mask : pd.DataFrame
851+
Elements of the dataframes to compute on
852+
853+
Returns
854+
-------
855+
pd.Series
856+
Correlation distance
857+
"""
858+
# For the case that we use this function outside pattern_based_metric
859+
df1 = df1[df_mask].fillna(0.0)
860+
df2 = df2[df_mask].fillna(0.0)
861+
862+
return 1.0 - pd.Series([dcor.distance_correlation(df1.values, df2.values)], index=["All"])
863+
864+
865+
def pattern_based_weighted_mean_metric(
866+
df1: pd.DataFrame,
867+
df2: pd.DataFrame,
868+
df_mask: pd.DataFrame,
869+
metric: Callable,
870+
min_num_row: int = 10,
871+
**kwargs,
872+
) -> pd.Series:
873+
"""Compute a mean score based on missing patterns.
874+
Note that for each pattern, a score is returned by the function metric.
875+
This code is based on https://www.statsmodels.org/
876+
877+
Parameters
878+
----------
879+
df1 : pd.DataFrame
880+
true dataframe
881+
df2 : pd.DataFrame
882+
predicted dataframe
883+
df_mask : pd.DataFrame
884+
Elements of the dataframes to compute on
885+
metric : Callable
886+
metric function
887+
min_num_row : int, optional
888+
minimum number of row allowed for a pattern without nan, by default 10
889+
890+
Returns
891+
-------
892+
pd.Series
893+
_description_
894+
"""
895+
# Identify all distinct missing patterns
896+
z = 1 + np.log(1 + np.arange(df_mask.shape[1]))
897+
c = np.dot(df_mask, z)
898+
row_map: Dict = {}
899+
for i, v in enumerate(c):
900+
if v == 0:
901+
# No missing values
902+
continue
903+
if v not in row_map:
904+
row_map[v] = []
905+
row_map[v].append(i)
906+
patterns = [np.asarray(v) for v in row_map.values()]
907+
scores = []
908+
weights = []
909+
for pattern in patterns:
910+
df1_pattern = df1.iloc[pattern].dropna(axis=1)
911+
if len(df1_pattern.columns) == 0:
912+
df1_pattern = df1.iloc[pattern].dropna(axis=0)
913+
914+
if len(df1_pattern) >= min_num_row:
915+
df2_pattern = df2.loc[df1_pattern.index, df1_pattern.columns]
916+
weights.append(1.0 / len(df1_pattern))
917+
scores.append(
918+
metric(df1_pattern, df2_pattern, ~df1_pattern.isna(), **kwargs).values[0]
919+
)
920+
921+
if len(scores) == 0:
922+
raise Exception(
923+
"Not found enough patterns. "
924+
+ f"Number of row for each pattern must be larger than min_num_row={min_num_row}."
925+
)
926+
927+
weighted_scores = np.array(scores) * np.array(weights)
928+
return pd.Series(np.sum(weighted_scores) / np.sum(weights), index=["All"])
929+
930+
838931
def get_metric(name: str) -> Callable:
839932
dict_metrics: Dict[str, Callable] = {
840933
"mse": mean_squared_error,
@@ -849,5 +942,9 @@ def get_metric(name: str) -> Callable:
849942
"pairwise_dist": sum_pairwise_distances,
850943
"energy": sum_energy_distances,
851944
"frechet": frechet_distance,
945+
"dist_corr_pattern": partial(
946+
pattern_based_weighted_mean_metric,
947+
metric=distance_correlation_complement,
948+
),
852949
}
853950
return dict_metrics[name]

tests/benchmark/test_metrics.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -333,3 +333,32 @@ def test_value_error_get_correlation_f_oneway_matrix(
333333
assert metrics.mean_diff_corr_matrix_categorical_vs_numerical_features(
334334
df1, df2, df_mask
335335
).equals(pd.Series([np.nan], index=["col1"]))
336+
337+
338+
@pytest.mark.parametrize("df1", [df_incomplete])
339+
@pytest.mark.parametrize("df2", [df_imputed])
340+
@pytest.mark.parametrize("df_mask", [df_mask])
341+
def test_distance_correlation_complement(
342+
df1: pd.DataFrame, df2: pd.DataFrame, df_mask: pd.DataFrame
343+
) -> None:
344+
result = metrics.distance_correlation_complement(df1, df2, df_mask)
345+
expected = pd.Series([0.001559], index=["All"])
346+
np.testing.assert_allclose(result, expected, atol=1e-3)
347+
348+
349+
@pytest.mark.parametrize("df1", [df_incomplete])
350+
@pytest.mark.parametrize("df2", [df_imputed])
351+
@pytest.mark.parametrize("df_mask", [df_mask])
352+
def test_pattern_based_weighted_mean_metric(
353+
df1: pd.DataFrame, df2: pd.DataFrame, df_mask: pd.DataFrame
354+
) -> None:
355+
with pytest.raises(Exception):
356+
metrics.pattern_based_weighted_mean_metric(
357+
df1, df2, df_mask, metric=metrics.distance_correlation_complement, min_num_row=5
358+
)
359+
360+
expected = pd.Series([2 / 3], index=["All"])
361+
result = metrics.pattern_based_weighted_mean_metric(
362+
df1, df2, df_mask, metric=metrics.distance_correlation_complement, min_num_row=1
363+
)
364+
np.testing.assert_allclose(result, expected, atol=1e-3)

0 commit comments

Comments
 (0)