44import logging
55from typing import Callable , List , Optional , Tuple , Union
66
7+ import random
78import numpy as np
89import pandas as pd
910from sklearn .model_selection import GroupShuffleSplit
@@ -753,8 +754,8 @@ def fit(self, X: pd.DataFrame) -> PatternHoleGenerator:
753754 super ().fit (X )
754755 df_isna = X .isna ().apply (lambda x : self .get_pattern (x ), axis = 1 ).to_frame (name = "pattern" )
755756 self .df_isna = df_isna ["pattern" ]
756-
757- patterns = self .df_isna . value_counts () .index .to_list ()
757+ self . patterns_counts = self . df_isna . value_counts ()
758+ patterns = self .patterns_counts .index .to_list ()
758759 if "_ALLNAN_" in patterns :
759760 patterns .remove ("_ALLNAN_" )
760761 if "_EMPTY_" in patterns :
@@ -772,11 +773,29 @@ def generate_mask(self, X: pd.DataFrame) -> pd.DataFrame:
772773 X : pd.DataFrame
773774 Initial dataframe with a missing pattern to be imitated.
774775 """
776+ if self .ngroups is not None :
777+ self .fit (X )
778+
775779 df_mask = pd .DataFrame (False , index = X .index , columns = X .columns )
776- patterns_selected = ["_EMPTY_" ] + [self .patterns [0 ]]
777- df_mask .loc [self .df_isna [self .df_isna .isin (patterns_selected )].index ] = True
778- df_mask [X .isna ()] = False
780+ patterns_selected = ["_EMPTY_" ]
781+ patterns = self .patterns
782+ for k in range (len (self .patterns )):
783+ pattern = random .choice (patterns )
784+ patterns_selected_ = patterns_selected + [pattern ]
785+ patterns .remove (pattern )
786+ df_mask_ = df_mask .copy ()
787+ X_ = X .copy ()
788+
789+ df_mask_ .loc [self .df_isna [self .df_isna .isin (patterns_selected_ )].index ] = True
790+ X_ [~ df_mask_ ] = np .nan
791+ X_ = X_ .dropna (axis = 0 , how = "all" ).dropna (axis = 1 , how = "any" )
792+ if X_ .size == 0 :
793+ break
794+ patterns_selected .append (pattern )
795+ if self .patterns_counts .loc [patterns_selected_ ].sum () / len (X ) > self .ratio_masked :
796+ break
779797
798+ df_mask .loc [self .df_isna [self .df_isna .isin (patterns_selected )].index ] = True
780799 return df_mask
781800
782801 def get_pattern (self , row : pd .Series ) -> str :
0 commit comments