Skip to content

Commit 9d5c2bd

Browse files
Julien RousselJulien Roussel
authored andcommitted
consistency test added
1 parent 1e7c2e4 commit 9d5c2bd

File tree

2 files changed

+56
-10
lines changed

2 files changed

+56
-10
lines changed

qolmat/benchmark/metrics.py

Lines changed: 39 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -801,6 +801,39 @@ def kl_divergence_1D(df1: pd.Series, df2: pd.Series) -> float:
801801
return scipy.stats.entropy(p + EPS, q + EPS)
802802

803803

804+
def kl_divergence_gaussian_exact(
805+
mean1: pd.Series, cov1: pd.DataFrame, mean2: pd.Series, cov2: pd.DataFrame
806+
) -> float:
807+
"""Exact Kullback-Leibler divergence computed between two multivariate normal distributions
808+
809+
Parameters
810+
----------
811+
mean1: pd.Series
812+
Mean of the first distribution
813+
cov1: pd.DataFrame
814+
Covariance matrx of the first distribution
815+
mean2: pd.Series
816+
Mean of the second distribution
817+
cov2: pd.DataFrame
818+
Covariance matrx of the second distribution
819+
Returns
820+
-------
821+
float
822+
Kulback-Leibler divergence
823+
"""
824+
n_variables = len(mean1)
825+
L1, lower1 = scipy.linalg.cho_factor(cov1)
826+
L2, lower2 = scipy.linalg.cho_factor(cov2)
827+
M = scipy.linalg.solve(L2, L1)
828+
y = scipy.linalg.solve(L2, mean2 - mean1)
829+
norm_M = (M**2).sum().sum()
830+
norm_y = (y**2).sum()
831+
term_diag_L = 2 * np.sum(np.log(np.diagonal(L2) / np.diagonal(L1)))
832+
print(norm_M, "-", n_variables, "+", norm_y, "+", term_diag_L)
833+
div_kl = 0.5 * (norm_M - n_variables + norm_y + term_diag_L)
834+
return div_kl
835+
836+
804837
def kl_divergence_gaussian(df1: pd.DataFrame, df2: pd.DataFrame, df_mask: pd.Series) -> float:
805838
"""Kullback-Leibler divergence estimation based on a Gaussian approximation of both empirical
806839
distributions
@@ -821,20 +854,12 @@ def kl_divergence_gaussian(df1: pd.DataFrame, df2: pd.DataFrame, df_mask: pd.Ser
821854
"""
822855
df1 = df1[df_mask.any(axis=1)]
823856
df2 = df2[df_mask.any(axis=1)]
824-
n_variables = len(df1.columns)
825857
cov1 = df1.cov()
826858
cov2 = df2.cov()
827859
mean1 = df1.mean()
828860
mean2 = df2.mean()
829-
L1, lower1 = scipy.linalg.cho_factor(cov1)
830-
L2, lower2 = scipy.linalg.cho_factor(cov2)
831-
M = scipy.linalg.solve(L2, L1)
832-
y = scipy.linalg.solve(L2, mean2 - mean1)
833-
norm_M = (M**2).sum().sum()
834-
norm_y = (y**2).sum()
835-
term_diag_L = 2 * np.sum(np.log(np.diagonal(L2) / np.diagonal(L1)))
836-
print(norm_M, "-", n_variables, "+", norm_y, "+", term_diag_L)
837-
div_kl = 0.5 * (norm_M - n_variables + norm_y + term_diag_L)
861+
862+
div_kl = kl_divergence_gaussian_exact(mean1, cov1, mean2, cov2)
838863
return div_kl
839864

840865

@@ -1017,6 +1042,10 @@ def pattern_based_weighted_mean_metric(
10171042
scores.append(metric(df1_pattern, df2_pattern, df_mask_pattern, **kwargs))
10181043
if len(scores) == 0:
10191044
raise NotEnoughSamples(max_num_row, min_n_rows)
1045+
print("scores:")
1046+
print(scores)
1047+
print("weights:")
1048+
print(weights)
10201049
return pd.Series(sum([s * w for s, w in zip(scores, weights)]), index=["All"])
10211050

10221051

tests/benchmark/test_metrics.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -360,3 +360,20 @@ def test_pattern_based_weighted_mean_metric(
360360
df1, df2, df_mask, metric=metrics.distance_anticorr, min_n_rows=1
361361
)
362362
np.testing.assert_allclose(result, expected, rtol=1e-2)
363+
364+
365+
rng = npr.default_rng(123)
366+
df_gauss1 = pd.DataFrame(rng.multivariate_normal([0, 0], [[1, 0.2], [0.2, 2]], size=100))
367+
df_gauss2 = pd.DataFrame(rng.multivariate_normal([0, 1], [[1, 0.2], [0.2, 2]], size=100))
368+
df_mask = pd.DataFrame(np.full_like(df_gauss1, True))
369+
370+
371+
def test_pattern_mae_comparison() -> None:
372+
def fun_mean_mae(df_gauss1, df_gauss2, df_mask) -> float:
373+
return metrics.mean_squared_error(df_gauss1, df_gauss2, df_mask).mean()
374+
375+
result1 = fun_mean_mae(df_gauss1, df_gauss2, df_mask)
376+
result2 = metrics.pattern_based_weighted_mean_metric(
377+
df_gauss1, df_gauss2, df_mask, metric=fun_mean_mae, min_n_rows=1
378+
)
379+
np.testing.assert_allclose(result1, result2, rtol=1e-2)

0 commit comments

Comments
 (0)