Skip to content

Commit bb82d3c

Browse files
committed
feat: add unit tests for pattern_based_weighted_mean_metric and distance_correlation_complement
1 parent 86628f1 commit bb82d3c

File tree

3 files changed

+35
-5
lines changed

3 files changed

+35
-5
lines changed

qolmat/benchmark/metrics.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -858,7 +858,7 @@ def distance_correlation_complement(
858858
df1 = df1[df_mask].fillna(0.0)
859859
df2 = df2[df_mask].fillna(0.0)
860860

861-
return 1.0 - pd.Series(dcor.distance_correlation(df1.values, df2.values), index=["All"])
861+
return 1.0 - pd.Series([dcor.distance_correlation(df1.values, df2.values)], index=["All"])
862862

863863

864864
def pattern_based_weighted_mean_metric(
@@ -912,10 +912,16 @@ def pattern_based_weighted_mean_metric(
912912

913913
if len(df1_pattern) >= min_num_row:
914914
df2_pattern = df2.loc[df1_pattern.index, df1_pattern.columns]
915-
weights.append(len(df1_pattern))
915+
weights.append(1.0 / len(df1_pattern))
916916
scores.append(
917917
metric(df1_pattern, df2_pattern, ~df1_pattern.isna(), **kwargs).values[0]
918918
)
919919

920+
if len(scores) == 0:
921+
raise Exception(
922+
"Not found enough patterns. "
923+
+ f"Number of row for each pattern must be larger than min_num_row={min_num_row}."
924+
)
925+
920926
weighted_scores = np.array(scores) * np.array(weights)
921-
return pd.Series(np.mean(weighted_scores), index=["All"])
927+
return pd.Series(np.sum(weighted_scores) / np.sum(weights), index=["All"])

qolmat/benchmark/missing_patterns.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,8 @@
22

33
import functools
44
import logging
5-
from typing import Callable, List, Optional, Tuple, Union
5+
from typing import Callable, List, Optional, Tuple
66

7-
import random
87
import numpy as np
98
import pandas as pd
109
from sklearn.model_selection import GroupShuffleSplit

tests/benchmark/test_metrics.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -333,3 +333,28 @@ def test_value_error_get_correlation_f_oneway_matrix(
333333
assert metrics.mean_diff_corr_matrix_categorical_vs_numerical_features(
334334
df1, df2, df_mask
335335
).equals(pd.Series([np.nan], index=["col1"]))
336+
337+
338+
@pytest.mark.parametrize("df1", [df_incomplete])
339+
@pytest.mark.parametrize("df2", [df_imputed])
340+
@pytest.mark.parametrize("df_mask", [df_mask])
341+
def test_distance_correlation_complement(
342+
df1: pd.DataFrame, df2: pd.DataFrame, df_mask: pd.DataFrame
343+
) -> None:
344+
result = metrics.distance_correlation_complement(df1, df2, df_mask)
345+
expected = pd.Series([0.001559], index=["All"])
346+
np.testing.assert_allclose(result, expected, atol=1e-3)
347+
348+
349+
@pytest.mark.parametrize("df1", [df_incomplete])
350+
@pytest.mark.parametrize("df2", [df_imputed])
351+
@pytest.mark.parametrize("df_mask", [df_mask])
352+
def test_pattern_based_weighted_mean_metric(
353+
df1: pd.DataFrame, df2: pd.DataFrame, df_mask: pd.DataFrame
354+
) -> None:
355+
356+
result = metrics.pattern_based_weighted_mean_metric(
357+
df1, df2, df_mask, metric=metrics.distance_correlation_complement, min_num_row=5
358+
)
359+
expected = pd.Series([2 / 3], index=["All"])
360+
np.testing.assert_allclose(result, expected, atol=1e-3)

0 commit comments

Comments
 (0)