11from __future__ import annotations
22
3+ import functools
34import logging
45from typing import List , Optional , Tuple
56
@@ -22,12 +23,26 @@ def compute_transition_matrix(states: pd.Series, ngroups: List = None):
2223 if ngroups is None :
2324 df_counts = compute_transition_counts_matrix (states )
2425 else :
25- df_counts = states .groupby (ngroups ).apply (compute_transition_counts_matrix ).sum ()
26+ list_counts = [compute_transition_counts_matrix (df ) for _ , df in states .groupby (ngroups )]
27+ df_counts = functools .reduce (lambda a , b : a .add (b , fill_value = 0 ), list_counts )
28+
2629 df_transition = df_counts .div (df_counts .sum (axis = 1 ), axis = 0 )
2730 return df_transition
2831
2932
30- class HoleGenerator :
def get_sizes_max(values_isna: pd.Series) -> pd.Series:
    """Compute, for every position, the largest hole that may end just before it.

    The series is split into runs of consecutive equal flags. For position
    ``i`` the result is the 0-based rank, inside its run, of the element at
    ``i - 1`` when that element is observed (``False``), and 0 when it is
    missing (``True``) or when ``i`` is the first position. A hole written to
    ``iloc[i - size : i]`` with ``size <= result[i]`` therefore covers only
    observed values and leaves the first value of the run untouched.

    Parameters
    ----------
    values_isna : pd.Series
        Boolean series, ``True`` where the value is missing.

    Returns
    -------
    pd.Series
        Integer series carrying the same index as ``values_isna``.
    """
    # A new run starts wherever the flag changes; the leading NaN produced by
    # ``diff`` also compares unequal to 0, so the first element opens run 1.
    ids_hole = (values_isna.diff() != 0).cumsum()
    # ``cumcount`` gives the rank inside each run and keeps the original index,
    # whereas ``groupby.apply`` may prepend the group keys depending on the
    # pandas version.
    rank_in_run = values_isna.groupby(ids_hole).cumcount()
    # Ranks only count inside observed runs; missing positions contribute 0.
    rank_observed = rank_in_run.where(~values_isna, 0)
    sizes_max = rank_observed.shift(1).fillna(0).astype(int)
    return sizes_max
43+
44+
45+ class _HoleGenerator :
3146 """
3247 This class implements a method to get indices of observed and missing values.
3348
@@ -59,7 +74,7 @@ def __init__(
5974 self .random_state = random_state
6075 self .groups = groups
6176
62- def fit (self , X : pd .DataFrame ) -> HoleGenerator :
77+ def fit (self , X : pd .DataFrame ) -> _HoleGenerator :
6378 """
6479 Fits the generator.
6580
@@ -118,7 +133,7 @@ def _check_subset(self, X: pd.DataFrame):
118133 )
119134
120135
121- class UniformHoleGenerator (HoleGenerator ):
136+ class UniformHoleGenerator (_HoleGenerator ):
122137 """This class implements a way to generate holes in a dataframe.
123138 The holes are generated randomly, using the resample method of scikit learn.
124139
@@ -176,8 +191,8 @@ def generate_mask(self, X: pd.DataFrame) -> pd.DataFrame:
176191 return df_mask
177192
178193
179- class SamplerHoleGenerator (HoleGenerator ):
180- """This class implements a way to generate holes in a dataframe.
194+ class _SamplerHoleGenerator (_HoleGenerator ):
195+ """This abstract class implements a generic way to generate holes in a dataframe by sampling 1D hole size distributions.
181196
182197 Parameters
183198 ----------
@@ -250,14 +265,7 @@ def generate_mask(self, X: pd.DataFrame) -> pd.DataFrame:
250265 for column in self .subset :
251266 states = X [column ].isna ()
252267
253- ids_hole = (states .diff () != 0 ).cumsum ()
254- sizes_max = (
255- states .groupby (ids_hole )
256- .apply (lambda x : (~ x ) * np .arange (len (x )))
257- .shift (1 )
258- .fillna (0 )
259- .astype (int )
260- )
268+ sizes_max = get_sizes_max (states )
261269 n_masked_left = n_masked_col
262270
263271 sizes_sampled = self .generate_hole_sizes (column , n_masked_col , sort = True )
@@ -284,7 +292,7 @@ def generate_mask(self, X: pd.DataFrame) -> pd.DataFrame:
284292 return mask
285293
286294
287- class GeometricHoleGenerator (SamplerHoleGenerator ):
295+ class GeometricHoleGenerator (_SamplerHoleGenerator ):
288296 """This class implements a way to generate holes in a dataframe.
289297 The holes are generated following a Markov 1D process.
290298
@@ -353,7 +361,7 @@ def sample_sizes(self, column, n_masked):
353361 return sizes_sampled
354362
355363
356- class EmpiricalHoleGenerator (SamplerHoleGenerator ):
364+ class EmpiricalHoleGenerator (_SamplerHoleGenerator ):
357365 """This class implements a way to generate holes in a dataframe.
358366 The distribution of holes is learned from the data.
359367 The distributions are learned column by column.
@@ -447,7 +455,7 @@ def sample_sizes(self, column, n_masked):
447455 return sizes_sampled
448456
449457
450- class MultiMarkovHoleGenerator (HoleGenerator ):
458+ class MultiMarkovHoleGenerator (_HoleGenerator ):
451459 """This class implements a way to generate holes in a dataframe.
452460 The holes are generated according to a Markov process.
453461 Each line of the dataframe mask (np.nan) represents a state of the Markov chain.
@@ -500,7 +508,7 @@ def fit(self, X: pd.DataFrame) -> MultiMarkovHoleGenerator:
500508 The model itself
501509
502510 """
503- self . _check_subset (X )
511+ super (). fit (X )
504512
505513 states = X [self .subset ].isna ().apply (lambda x : tuple (x ), axis = 1 )
506514 self .df_transition = compute_transition_matrix (states , self .ngroups )
@@ -564,30 +572,38 @@ def generate_mask(self, X: pd.DataFrame) -> List[pd.DataFrame]:
564572 X_subset = X [self .subset ]
565573 mask = pd .DataFrame (False , columns = X_subset .columns , index = X_subset .index )
566574
567- mask_init = X_subset .isna ().any (axis = 1 )
568- n_masked = X [self .subset ].size * self .ratio_masked
575+ values_hasna = X_subset .isna ().any (axis = 1 )
576+
577+ sizes_max = get_sizes_max (values_hasna )
578+ n_masked_left = int (X [self .subset ].size * self .ratio_masked )
569579
570- realisations = self .generate_multi_realisation (n_masked )
580+ realisations = self .generate_multi_realisation (n_masked_left )
571581 realisations = sorted (realisations , reverse = True )
572582 for realisation in realisations :
573583 size_hole = len (realisation )
574- is_valid = (
575- ~ (mask_init | mask ).T .all ().rolling (size_hole + 2 ).max ().fillna (1 ).astype (bool )
576- )
577- if not np .any (is_valid ):
578- logger .warning (f"No place to introduce sampled hole of size { size_hole } !" )
579- continue
580- i_hole = np .random .choice (np .where (is_valid )[0 ])
584+ n_masked = sum ([sum (row ) for row in realisation ])
585+ size_hole = min (size_hole , sizes_max .max ())
586+ realisation = realisation [:size_hole ]
587+ i_hole = np .random .choice (np .where (size_hole <= sizes_max )[0 ])
588+ assert (~ mask .iloc [i_hole - size_hole : i_hole ]).all ().all ()
581589 mask .iloc [i_hole - size_hole : i_hole ] = mask .iloc [i_hole - size_hole : i_hole ].where (
582590 ~ np .array (realisation ), other = True
583591 )
592+ n_masked_left -= n_masked
593+
594+ sizes_max .iloc [i_hole - size_hole : i_hole ] = 0
595+ sizes_max .iloc [i_hole :] = np .minimum (
596+ sizes_max .iloc [i_hole :], np .arange (len (sizes_max .iloc [i_hole :]))
597+ )
598+ if n_masked_left <= 0 :
599+ break
584600
585601 complete_mask = pd .DataFrame (False , columns = X .columns , index = X .index )
586602 complete_mask [self .subset ] = mask [self .subset ]
587603 return mask
588604
589605
590- class GroupedHoleGenerator (HoleGenerator ):
606+ class GroupedHoleGenerator (_HoleGenerator ):
591607 """This class implements a way to generate holes in a dataframe.
592608 The holes are generated from groups, specified by the user.
593609 This class uses the GroupShuffleSplit function of sklearn.
0 commit comments