@@ -699,110 +699,3 @@ def split(self, X: pd.DataFrame) -> List[pd.DataFrame]:
699699 list_masks .append (df_mask )
700700
701701 return list_masks
702-
703-
class PatternHoleGenerator(_HoleGenerator):
    """Generate holes by replicating missing-value patterns found in the data.

    A "pattern" is the set of columns that are simultaneously missing on a
    row, encoded as the column names joined by ``"__"``. Two reserved labels
    are used: ``"_EMPTY_"`` for a row without any missing value and
    ``"_ALLNAN_"`` for a row where every column is missing.

    Parameters
    ----------
    n_splits : int
        Number of splits
    ratio_masked : float, optional
        Ratio of masked values to add, by default 0.05.
    random_state : Union[None, int, np.random.RandomState], optional
        The seed used by the random number generator, by default None.
    groups : Optional[List[str]], optional
        Names of the columns forming the groups. ``None`` (the default) is
        equivalent to an empty list.

    Notes
    -----
    ``subset`` is not configurable for this generator: ``None`` is always
    forwarded to the parent constructor.
    """

    def __init__(
        self,
        n_splits: int,
        ratio_masked: float = 0.05,
        random_state: Union[None, int, np.random.RandomState] = None,
        groups: Union[None, List[str]] = None,
    ):
        # A fresh list is created per instance instead of sharing one mutable
        # default object between every generator built without ``groups``.
        super().__init__(
            n_splits=n_splits,
            subset=None,
            random_state=random_state,
            ratio_masked=ratio_masked,
            groups=[] if groups is None else groups,
        )

    def fit(self, X: pd.DataFrame) -> "PatternHoleGenerator":
        """Compute the missing-value pattern of every row of ``X``.

        The parent ``fit`` is called first; then the following attributes
        are set:

        * ``df_isna``: Series (named ``"pattern"``) giving the pattern label
          of each row of ``X``;
        * ``patterns_counts``: number of rows per pattern label;
        * ``patterns``: pattern labels, excluding ``"_ALLNAN_"`` and
          ``"_EMPTY_"``.

        Parameters
        ----------
        X : pd.DataFrame
            Dataframe whose missing-value patterns are extracted.

        Returns
        -------
        PatternHoleGenerator
            The model itself

        Raises
        ------
        Whatever the parent ``fit`` raises.
        """
        super().fit(X)
        self.df_isna = X.isna().apply(self.get_pattern, axis=1).rename("pattern")
        self.patterns_counts = self.df_isna.value_counts()
        # The two reserved labels are not partial patterns and are never drawn.
        self.patterns = [
            pattern
            for pattern in self.patterns_counts.index.to_list()
            if pattern not in ("_ALLNAN_", "_EMPTY_")
        ]

        return self

    def generate_mask(self, X: pd.DataFrame) -> pd.DataFrame:
        """
        Returns a mask for the dataframe at hand.

        Patterns are drawn at random, without replacement, among the fitted
        ones. The rows carrying a selected pattern (or the ``"_EMPTY_"``
        label, which is always selected) are flagged in the mask. Drawing
        stops when the candidate selection leaves no usable data, or when the
        share of selected rows exceeds ``ratio_masked``.

        Parameters
        ----------
        X : pd.DataFrame
            Initial dataframe with a missing pattern to be imitated.

        Returns
        -------
        pd.DataFrame
            Boolean dataframe with the index and columns of ``X``; flagged
            rows are entirely ``True``.
        """
        # NOTE(review): the generator is only refitted when groups are in use;
        # otherwise the attributes of a previous ``fit`` call are reused.
        if self.ngroups is not None:
            self.fit(X)

        df_mask = pd.DataFrame(False, index=X.index, columns=X.columns)
        patterns_selected = ["_EMPTY_"]
        # Draw from a copy: removing items from ``self.patterns`` itself would
        # leave no pattern available for the next call.
        remaining = list(self.patterns)
        # NOTE(review): the draw uses the global ``random`` module, not
        # ``random_state`` -- confirm whether this is intended.
        for _ in range(len(remaining)):
            pattern = random.choice(remaining)
            remaining.remove(pattern)
            candidates = patterns_selected + [pattern]

            df_mask_ = df_mask.copy()
            X_ = X.copy()
            df_mask_.loc[self.df_isna[self.df_isna.isin(candidates)].index] = True
            # Keep only the cells of the candidate rows and check that a
            # fully observed sub-block remains once sparse rows/columns go.
            X_[~df_mask_] = np.nan
            X_ = X_.dropna(axis=0, how="all").dropna(axis=1, how="any")
            if X_.size == 0:
                break
            patterns_selected.append(pattern)
            # ``reindex`` counts absent labels as 0 ("_EMPTY_" is missing from
            # the counts when no row is complete) where ``.loc`` would raise.
            n_selected = self.patterns_counts.reindex(candidates, fill_value=0).sum()
            if n_selected / len(X) > self.ratio_masked:
                break

        df_mask.loc[self.df_isna[self.df_isna.isin(patterns_selected)].index] = True
        return df_mask

    def get_pattern(self, row: pd.Series) -> str:
        """Return the label of the missing-value pattern of a row.

        Parameters
        ----------
        row : pd.Series
            Boolean series indexed by column name, ``True`` where the value
            is missing.

        Returns
        -------
        str
            ``"_EMPTY_"`` if no entry is flagged, ``"_ALLNAN_"`` if all of
            them are, otherwise the flagged column names joined by ``"__"``.
        """
        flagged_cols = [col for col, is_missing in row.items() if is_missing]
        if not flagged_cols:
            return "_EMPTY_"
        if len(flagged_cols) == row.index.size:
            return "_ALLNAN_"
        # ``str`` keeps the join valid for non-string column labels.
        return "__".join(map(str, flagged_cols))
0 commit comments