88from sklearn import metrics as skm
99from sklearn .ensemble import BaseEnsemble
1010from sklearn .preprocessing import StandardScaler
11+ import dcor
1112
# Machine epsilon for float64: smallest eps such that 1.0 + eps != 1.0.
# Used to guard divisions/logs against exact zeros elsewhere in this module.
EPS = np.finfo(float).eps
1314
@@ -835,6 +836,98 @@ def frechet_distance(
835836 return pd .Series (np .repeat (frechet_dist , len (df1 .columns )))
836837
837838
def distance_correlation_complement(
    df1: pd.DataFrame, df2: pd.DataFrame, df_mask: pd.DataFrame
) -> pd.Series:
    """Correlation distance between columns of 2 dataframes.

    Computes ``1 - dcor.distance_correlation`` over the masked entries of the
    two dataframes, so identical dataframes score 0 and independent ones
    score close to 1.

    Parameters
    ----------
    df1 : pd.DataFrame
        true dataframe
    df2 : pd.DataFrame
        predicted dataframe
    df_mask : pd.DataFrame
        Elements of the dataframes to compute on

    Returns
    -------
    pd.Series
        Correlation distance
    """
    # Masking is applied here so the function also works when called
    # directly, outside of pattern_based_weighted_mean_metric.
    values_true = df1[df_mask].fillna(0.0)
    values_pred = df2[df_mask].fillna(0.0)

    dcorr = dcor.distance_correlation(values_true.values, values_pred.values)
    return pd.Series([1.0 - dcorr], index=["All"])
864+
def pattern_based_weighted_mean_metric(
    df1: pd.DataFrame,
    df2: pd.DataFrame,
    df_mask: pd.DataFrame,
    metric: Callable,
    min_num_row: int = 10,
    **kwargs,
) -> pd.Series:
    """Compute a mean score based on missing patterns.

    Rows of ``df_mask`` are grouped by their boolean pattern; ``metric`` is
    evaluated once per sufficiently large pattern and the per-pattern scores
    are combined as a weighted mean with weights ``1 / n_rows``.
    This code is based on https://www.statsmodels.org/

    Parameters
    ----------
    df1 : pd.DataFrame
        true dataframe
    df2 : pd.DataFrame
        predicted dataframe
    df_mask : pd.DataFrame
        Elements of the dataframes to compute on
    metric : Callable
        metric function, called as ``metric(df1_sub, df2_sub, mask, **kwargs)``
        and expected to return a pd.Series whose first value is the score
    min_num_row : int, optional
        minimum number of rows required for a pattern to be scored,
        by default 10

    Returns
    -------
    pd.Series
        Single-element series (index ``["All"]``) holding the weighted mean
        of the per-pattern scores.

    Raises
    ------
    ValueError
        If no pattern has at least ``min_num_row`` rows.
    """
    # Identify all distinct missing patterns: hash each row's boolean mask
    # into a scalar via a log-spaced column weighting, so that two rows get
    # the same hash iff they share the same pattern.
    z = 1 + np.log(1 + np.arange(df_mask.shape[1]))
    c = np.dot(df_mask, z)
    row_map: Dict = {}
    for i, v in enumerate(c):
        if v == 0:
            # Mask selects nothing in this row: nothing to score, skip it.
            continue
        row_map.setdefault(v, []).append(i)
    patterns = [np.asarray(v) for v in row_map.values()]

    scores = []
    weights = []
    for pattern in patterns:
        # Keep only fully-observed columns for this group of rows.
        df1_pattern = df1.iloc[pattern].dropna(axis=1)
        if len(df1_pattern.columns) == 0:
            # Every column contains a NaN: fall back to dropping rows instead.
            df1_pattern = df1.iloc[pattern].dropna(axis=0)

        if len(df1_pattern) >= min_num_row:
            df2_pattern = df2.loc[df1_pattern.index, df1_pattern.columns]
            weights.append(1.0 / len(df1_pattern))
            scores.append(
                metric(df1_pattern, df2_pattern, ~df1_pattern.isna(), **kwargs).values[0]
            )

    if len(scores) == 0:
        # ValueError (a subclass of Exception) keeps existing handlers working.
        raise ValueError(
            "Not enough patterns found. "
            + f"Each pattern must have at least min_num_row={min_num_row} rows."
        )

    weighted_scores = np.array(scores) * np.array(weights)
    return pd.Series(np.sum(weighted_scores) / np.sum(weights), index=["All"])
929+
930+
838931def get_metric (name : str ) -> Callable :
839932 dict_metrics : Dict [str , Callable ] = {
840933 "mse" : mean_squared_error ,
@@ -849,5 +942,9 @@ def get_metric(name: str) -> Callable:
849942 "pairwise_dist" : sum_pairwise_distances ,
850943 "energy" : sum_energy_distances ,
851944 "frechet" : frechet_distance ,
945+ "dist_corr_pattern" : partial (
946+ pattern_based_weighted_mean_metric ,
947+ metric = distance_correlation_complement ,
948+ ),
852949 }
853950 return dict_metrics [name ]
0 commit comments