@@ -892,40 +892,22 @@ def pattern_based_weighted_mean_metric(
892892 pd.Series
893893 _description_
894894 """
895- # Identify all distinct missing patterns
896- z = 1 + np .log (1 + np .arange (df_mask .shape [1 ]))
897- c = np .dot (df_mask , z )
898- row_map : Dict = {}
899- for i , v in enumerate (c ):
900- if v == 0 :
901- # No missing values
902- continue
903- if v not in row_map :
904- row_map [v ] = []
905- row_map [v ].append (i )
906- patterns = [np .asarray (v ) for v in row_map .values ()]
907895 scores = []
908896 weights = []
909- for pattern in patterns :
910- df1_pattern = df1 .iloc [pattern ].dropna (axis = 1 )
911- if len (df1_pattern .columns ) == 0 :
912- df1_pattern = df1 .iloc [pattern ].dropna (axis = 0 )
913-
914- if len (df1_pattern ) >= min_num_row :
915- df2_pattern = df2 .loc [df1_pattern .index , df1_pattern .columns ]
916- weights .append (1.0 / len (df1_pattern ))
917- scores .append (
918- metric (df1_pattern , df2_pattern , ~ df1_pattern .isna (), ** kwargs ).values [0 ]
919- )
920-
897+ for tup_pattern , df_mask_pattern in df_mask .groupby (df_mask .columns .tolist ()):
898+ ind_pattern = df_mask_pattern .index
899+ df1_pattern = df1 .loc [ind_pattern , list (tup_pattern )]
900+ if not any (tup_pattern ) or len (df1_pattern ) < min_num_row :
901+ continue
902+ df2_pattern = df2 .loc [ind_pattern , list (tup_pattern )]
903+ weights .append (1.0 / len (df1_pattern ))
904+ scores .append (metric (df1_pattern , df2_pattern , ~ df1_pattern .isna (), ** kwargs ))
921905 if len (scores ) == 0 :
922906 raise Exception (
923907 "Not found enough patterns. "
924- + f"Number of row for each pattern must be larger than min_num_row={ min_num_row } ."
908+ f"Number of row for each pattern must be larger than min_num_row={ min_num_row } ."
925909 )
926-
927- weighted_scores = np .array (scores ) * np .array (weights )
928- return pd .Series (np .sum (weighted_scores ) / np .sum (weights ), index = ["All" ])
910+ return pd .Series (sum ([s * w for s , w in zip (scores , weights )]) / sum (weights ), index = ["All" ])
929911
930912
931913def get_metric (name : str ) -> Callable :
0 commit comments