Skip to content

Commit 34fbc10

Browse files
Julien Roussel authored and committed
pattern simplified
1 parent 4f0218a commit 34fbc10

File tree

1 file changed

+10
-28
lines changed

1 file changed

+10
-28
lines changed

qolmat/benchmark/metrics.py

Lines changed: 10 additions & 28 deletions
Original file line number | Diff line number | Diff line change
@@ -892,40 +892,22 @@ def pattern_based_weighted_mean_metric(
892892
pd.Series
893893
_description_
894894
"""
895-
# Identify all distinct missing patterns
896-
z = 1 + np.log(1 + np.arange(df_mask.shape[1]))
897-
c = np.dot(df_mask, z)
898-
row_map: Dict = {}
899-
for i, v in enumerate(c):
900-
if v == 0:
901-
# No missing values
902-
continue
903-
if v not in row_map:
904-
row_map[v] = []
905-
row_map[v].append(i)
906-
patterns = [np.asarray(v) for v in row_map.values()]
907895
scores = []
908896
weights = []
909-
for pattern in patterns:
910-
df1_pattern = df1.iloc[pattern].dropna(axis=1)
911-
if len(df1_pattern.columns) == 0:
912-
df1_pattern = df1.iloc[pattern].dropna(axis=0)
913-
914-
if len(df1_pattern) >= min_num_row:
915-
df2_pattern = df2.loc[df1_pattern.index, df1_pattern.columns]
916-
weights.append(1.0 / len(df1_pattern))
917-
scores.append(
918-
metric(df1_pattern, df2_pattern, ~df1_pattern.isna(), **kwargs).values[0]
919-
)
920-
897+
for tup_pattern, df_mask_pattern in df_mask.groupby(df_mask.columns.tolist()):
898+
ind_pattern = df_mask_pattern.index
899+
df1_pattern = df1.loc[ind_pattern, list(tup_pattern)]
900+
if not any(tup_pattern) or len(df1_pattern) < min_num_row:
901+
continue
902+
df2_pattern = df2.loc[ind_pattern, list(tup_pattern)]
903+
weights.append(1.0 / len(df1_pattern))
904+
scores.append(metric(df1_pattern, df2_pattern, ~df1_pattern.isna(), **kwargs))
921905
if len(scores) == 0:
922906
raise Exception(
923907
"Not found enough patterns. "
924-
+ f"Number of row for each pattern must be larger than min_num_row={min_num_row}."
908+
f"Number of row for each pattern must be larger than min_num_row={min_num_row}."
925909
)
926-
927-
weighted_scores = np.array(scores) * np.array(weights)
928-
return pd.Series(np.sum(weighted_scores) / np.sum(weights), index=["All"])
910+
return pd.Series(sum([s * w for s, w in zip(scores, weights)]) / sum(weights), index=["All"])
929911

930912

931913
def get_metric(name: str) -> Callable:

0 commit comments

Comments
 (0)