Skip to content

Commit f6a6646

Browse files
author
Rima Hajou
committed
Merge branch 'patch_multimarkov' into 'dev'
multimarkov patched See merge request quantmetry/retd/qolmat!9
2 parents 0bdc036 + 9ffea27 commit f6a6646

File tree

3 files changed

+52
-31
lines changed

3 files changed

+52
-31
lines changed

qolmat/benchmark/comparator.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
import pandas as pd
55

66
from qolmat.benchmark import cross_validation, utils
7-
from qolmat.benchmark.missing_patterns import HoleGenerator
7+
from qolmat.benchmark.missing_patterns import _HoleGenerator
88

99

1010
class Comparator:
@@ -31,7 +31,7 @@ def __init__(
3131
self,
3232
dict_models: Dict,
3333
selected_columns: List[str],
34-
generator_holes: HoleGenerator,
34+
generator_holes: _HoleGenerator,
3535
columnwise_evaluation: Optional[bool] = True,
3636
search_params: Optional[Dict] = {},
3737
n_cv_calls: Optional[int] = 10,

qolmat/benchmark/missing_patterns.py

Lines changed: 45 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from __future__ import annotations
22

3+
import functools
34
import logging
45
from typing import List, Optional, Tuple
56

@@ -22,12 +23,26 @@ def compute_transition_matrix(states: pd.Series, ngroups: List = None):
2223
if ngroups is None:
2324
df_counts = compute_transition_counts_matrix(states)
2425
else:
25-
df_counts = states.groupby(ngroups).apply(compute_transition_counts_matrix).sum()
26+
list_counts = [compute_transition_counts_matrix(df) for _, df in states.groupby(ngroups)]
27+
df_counts = functools.reduce(lambda a, b: a.add(b, fill_value=0), list_counts)
28+
2629
df_transition = df_counts.div(df_counts.sum(axis=1), axis=0)
2730
return df_transition
2831

2932

30-
class HoleGenerator:
33+
def get_sizes_max(values_isna: pd.Series) -> pd.Series:
    """Compute, for every row, the largest hole size that may end just before it.

    The series is cut into runs of consecutive identical flags. For a row at
    position ``i`` the result is the 0-based rank of row ``i - 1`` inside its
    run when that row is observed, and 0 when row ``i - 1`` is missing or when
    ``i`` is the first row. Equivalently: the number of consecutive observed
    rows immediately preceding ``i``, minus one, floored at 0. A hole of that
    size written on rows ``[i - size, i)`` therefore leaves at least one
    observed row in front of it.

    Parameters
    ----------
    values_isna : pd.Series
        Boolean series, True where the value is missing.

    Returns
    -------
    pd.Series
        Integer series carrying the same index as ``values_isna``.
    """
    # A new run starts wherever the flag differs from the previous row; the
    # first row compares against NaN and therefore always opens run number 1.
    ids_hole = values_isna.ne(values_isna.shift()).cumsum()
    # 0-based rank of each row inside its run. Unlike groupby(...).apply with a
    # lambda, cumcount always keeps the original index (no group key level is
    # prepended, whatever the pandas version) and avoids a Python-level call
    # per run.
    ranks = values_isna.groupby(ids_hole).cumcount()
    # Rows that are already missing cannot host a new hole.
    ranks_observed = ranks.where(~values_isna, 0)
    # Shift by one row so that position i describes the rows strictly before i.
    return ranks_observed.shift(1).fillna(0).astype(int)
43+
44+
45+
class _HoleGenerator:
3146
"""
3247
This class implements a method to get indices of observed and missing values.
3348
@@ -59,7 +74,7 @@ def __init__(
5974
self.random_state = random_state
6075
self.groups = groups
6176

62-
def fit(self, X: pd.DataFrame) -> HoleGenerator:
77+
def fit(self, X: pd.DataFrame) -> _HoleGenerator:
6378
"""
6479
Fits the generator.
6580
@@ -118,7 +133,7 @@ def _check_subset(self, X: pd.DataFrame):
118133
)
119134

120135

121-
class UniformHoleGenerator(HoleGenerator):
136+
class UniformHoleGenerator(_HoleGenerator):
122137
"""This class implements a way to generate holes in a dataframe.
123138
The holes are generated randomly, using the resample method of scikit learn.
124139
@@ -176,8 +191,8 @@ def generate_mask(self, X: pd.DataFrame) -> pd.DataFrame:
176191
return df_mask
177192

178193

179-
class SamplerHoleGenerator(HoleGenerator):
180-
"""This class implements a way to generate holes in a dataframe.
194+
class _SamplerHoleGenerator(_HoleGenerator):
195+
"""This abstract class implements a generic way to generate holes in a dataframe by sampling 1D hole size distributions.
181196
182197
Parameters
183198
----------
@@ -250,14 +265,7 @@ def generate_mask(self, X: pd.DataFrame) -> pd.DataFrame:
250265
for column in self.subset:
251266
states = X[column].isna()
252267

253-
ids_hole = (states.diff() != 0).cumsum()
254-
sizes_max = (
255-
states.groupby(ids_hole)
256-
.apply(lambda x: (~x) * np.arange(len(x)))
257-
.shift(1)
258-
.fillna(0)
259-
.astype(int)
260-
)
268+
sizes_max = get_sizes_max(states)
261269
n_masked_left = n_masked_col
262270

263271
sizes_sampled = self.generate_hole_sizes(column, n_masked_col, sort=True)
@@ -284,7 +292,7 @@ def generate_mask(self, X: pd.DataFrame) -> pd.DataFrame:
284292
return mask
285293

286294

287-
class GeometricHoleGenerator(SamplerHoleGenerator):
295+
class GeometricHoleGenerator(_SamplerHoleGenerator):
288296
"""This class implements a way to generate holes in a dataframe.
289297
The holes are generated following a Markov 1D process.
290298
@@ -353,7 +361,7 @@ def sample_sizes(self, column, n_masked):
353361
return sizes_sampled
354362

355363

356-
class EmpiricalHoleGenerator(SamplerHoleGenerator):
364+
class EmpiricalHoleGenerator(_SamplerHoleGenerator):
357365
"""This class implements a way to generate holes in a dataframe.
358366
The distribution of holes is learned from the data.
359367
The distributions are learned column by column.
@@ -447,7 +455,7 @@ def sample_sizes(self, column, n_masked):
447455
return sizes_sampled
448456

449457

450-
class MultiMarkovHoleGenerator(HoleGenerator):
458+
class MultiMarkovHoleGenerator(_HoleGenerator):
451459
"""This class implements a way to generate holes in a dataframe.
452460
The holes are generated according to a Markov process.
453461
Each line of the dataframe mask (np.nan) represents a state of the Markov chain.
@@ -500,7 +508,7 @@ def fit(self, X: pd.DataFrame) -> MultiMarkovHoleGenerator:
500508
The model itself
501509
502510
"""
503-
self._check_subset(X)
511+
super().fit(X)
504512

505513
states = X[self.subset].isna().apply(lambda x: tuple(x), axis=1)
506514
self.df_transition = compute_transition_matrix(states, self.ngroups)
@@ -564,30 +572,38 @@ def generate_mask(self, X: pd.DataFrame) -> List[pd.DataFrame]:
564572
X_subset = X[self.subset]
565573
mask = pd.DataFrame(False, columns=X_subset.columns, index=X_subset.index)
566574

567-
mask_init = X_subset.isna().any(axis=1)
568-
n_masked = X[self.subset].size * self.ratio_masked
575+
values_hasna = X_subset.isna().any(axis=1)
576+
577+
sizes_max = get_sizes_max(values_hasna)
578+
n_masked_left = int(X[self.subset].size * self.ratio_masked)
569579

570-
realisations = self.generate_multi_realisation(n_masked)
580+
realisations = self.generate_multi_realisation(n_masked_left)
571581
realisations = sorted(realisations, reverse=True)
572582
for realisation in realisations:
573583
size_hole = len(realisation)
574-
is_valid = (
575-
~(mask_init | mask).T.all().rolling(size_hole + 2).max().fillna(1).astype(bool)
576-
)
577-
if not np.any(is_valid):
578-
logger.warning(f"No place to introduce sampled hole of size {size_hole}!")
579-
continue
580-
i_hole = np.random.choice(np.where(is_valid)[0])
584+
n_masked = sum([sum(row) for row in realisation])
585+
size_hole = min(size_hole, sizes_max.max())
586+
realisation = realisation[:size_hole]
587+
i_hole = np.random.choice(np.where(size_hole <= sizes_max)[0])
588+
assert (~mask.iloc[i_hole - size_hole : i_hole]).all().all()
581589
mask.iloc[i_hole - size_hole : i_hole] = mask.iloc[i_hole - size_hole : i_hole].where(
582590
~np.array(realisation), other=True
583591
)
592+
n_masked_left -= n_masked
593+
594+
sizes_max.iloc[i_hole - size_hole : i_hole] = 0
595+
sizes_max.iloc[i_hole:] = np.minimum(
596+
sizes_max.iloc[i_hole:], np.arange(len(sizes_max.iloc[i_hole:]))
597+
)
598+
if n_masked_left <= 0:
599+
break
584600

585601
complete_mask = pd.DataFrame(False, columns=X.columns, index=X.index)
586602
complete_mask[self.subset] = mask[self.subset]
587603
return mask
588604

589605

590-
class GroupedHoleGenerator(HoleGenerator):
606+
class GroupedHoleGenerator(_HoleGenerator):
591607
"""This class implements a way to generate holes in a dataframe.
592608
The holes are generated from groups, specified by the user.
593609
This class uses the GroupShuffleSplit function of sklearn.

qolmat/notebooks/benchmark.md

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -185,13 +185,18 @@ This allows an easy comparison of the different imputations.
185185

186186
Note that these metrics compute reconstruction errors; they tell us nothing about the distances between the "true" and "imputed" distributions.
187187

188+
```python
189+
missing_patterns.EmpiricalHoleGenerator(n_splits=2, groups=["station"], ratio_masked=0.1)
190+
```
191+
188192
```python
189193
doy = pd.Series(df_data.reset_index().datetime.dt.isocalendar().week.values, index=df_data.index)
190194

191195
generator_holes = missing_patterns.EmpiricalHoleGenerator(n_splits=2, groups=["station"], ratio_masked=0.1)
192196
# generator_holes = missing_patterns.GeometricHoleGenerator(n_splits=10, groups=["station"], ratio_masked=0.1)
193197
# generator_holes = missing_patterns.UniformHoleGenerator(n_splits=2, ratio_masked=0.4)
194198
# generator_holes = missing_patterns.GroupedHoleGenerator(n_splits=2, groups=["station", doy], ratio_masked=0.4)
199+
# generator_holes = missing_patterns.MultiMarkovHoleGenerator(n_splits=2, groups=["station"], ratio_masked=0.1)
195200

196201
comparison = comparator.Comparator(
197202
dict_models,

0 commit comments

Comments
 (0)