Skip to content

Commit 86628f1

Browse files
committed
docs: pattern_based_weighted_mean_metric
1 parent 5967f77 commit 86628f1

File tree

2 files changed

+14
-7
lines changed

2 files changed

+14
-7
lines changed

qolmat/benchmark/comparator.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,8 @@ class Comparator:
4242
"energy": metrics.sum_energy_distances,
4343
"frechet": metrics.frechet_distance,
4444
"dist_corr_pattern": partial(
45-
metrics.pattern_based_metric, metric=metrics.distance_correlation_complement
45+
metrics.pattern_based_weighted_mean_metric,
46+
metric=metrics.distance_correlation_complement,
4647
),
4748
}
4849

qolmat/benchmark/metrics.py

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -861,15 +861,17 @@ def distance_correlation_complement(
861861
return 1.0 - pd.Series(dcor.distance_correlation(df1.values, df2.values), index=["All"])
862862

863863

864-
def pattern_based_metric(
864+
def pattern_based_weighted_mean_metric(
865865
df1: pd.DataFrame,
866866
df2: pd.DataFrame,
867867
df_mask: pd.DataFrame,
868868
metric: Callable,
869869
min_num_row: int = 10,
870870
**kwargs,
871871
) -> pd.Series:
872-
"""_summary_
872+
"""Compute a mean score based on missing patterns.
873+
Note that for each pattern, a score is returned by the function metric.
874+
This code is based on https://www.statsmodels.org/
873875
874876
Parameters
875877
----------
@@ -889,7 +891,7 @@ def pattern_based_metric(
889891
pd.Series
890892
_description_
891893
"""
892-
# Identify all distinct missing data patterns
894+
# Identify all distinct missing patterns
893895
z = 1 + np.log(1 + np.arange(df_mask.shape[1]))
894896
c = np.dot(df_mask, z)
895897
row_map: Dict = {}
@@ -902,14 +904,18 @@ def pattern_based_metric(
902904
row_map[v].append(i)
903905
patterns = [np.asarray(v) for v in row_map.values()]
904906
scores = []
907+
weights = []
905908
for pattern in patterns:
906909
df1_pattern = df1.iloc[pattern].dropna(axis=1)
907910
if len(df1_pattern.columns) == 0:
908911
df1_pattern = df1.iloc[pattern].dropna(axis=0)
909912

910913
if len(df1_pattern) >= min_num_row:
911914
df2_pattern = df2.loc[df1_pattern.index, df1_pattern.columns]
915+
weights.append(len(df1_pattern))
916+
scores.append(
917+
metric(df1_pattern, df2_pattern, ~df1_pattern.isna(), **kwargs).values[0]
918+
)
912919

913-
scores.append(metric(df1_pattern, df2_pattern, ~df1_pattern.isna(), **kwargs))
914-
915-
return pd.Series(np.mean(scores), index=["All"])
920+
weighted_scores = np.array(scores) * np.array(weights)
921+
return pd.Series(np.mean(weighted_scores), index=["All"])

0 commit comments

Comments
 (0)