Skip to content

Commit e3ae559

Browse files
Merge pull request #150 from scikit-learn-contrib/chp_add_rand_state_ddpm
Chp add rand state ddpm
2 parents 47565ff + bb53bae commit e3ae559

File tree

6 files changed

+77
-20
lines changed

6 files changed

+77
-20
lines changed

examples/benchmark.md

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -305,6 +305,32 @@ except ModuleNotFoundError:
305305
For the example, we use a simple MLP model with 3 layers of neurons.
306306
Then we train the model without taking a group on the stations
307307

308+
```python
309+
import numpy as np
310+
from qolmat.imputations.imputers_pytorch import ImputerDiffusion
311+
from qolmat.imputations.diffusions.ddpms import TabDDPM
312+
313+
X = np.array([[1, 1, 1, 1], [np.nan, np.nan, 3, 2], [1, 2, 2, 1], [2, 2, 2, 2]])
314+
imputer = ImputerDiffusion(model=TabDDPM(random_state=11), epochs=50, batch_size=1)
315+
316+
imputer.fit_transform(X)
317+
```
318+
319+
```python
320+
import numpy as np
321+
from qolmat.imputations.imputers_pytorch import ImputerDiffusion
322+
from qolmat.imputations.diffusions.ddpms import TabDDPM
323+
324+
X = np.array([[1, 1, 1, 1], [np.nan, np.nan, 3, 2], [1, 2, 2, 1], [2, 2, 2, 2]])
325+
imputer = ImputerDiffusion(model=TabDDPM(random_state=11), epochs=50, batch_size=1)
326+
327+
imputer.fit_transform(X)
328+
```
329+
330+
```python
331+
1.33573675, 1.40472937
332+
```
333+
308334
```python
309335
fig = plt.figure(figsize=(10 * n_stations, 3 * n_cols))
310336
for i_station, (station, df) in enumerate(df_data.groupby("station")):

qolmat/analysis/__init__.py

Whitespace-only changes.

qolmat/analysis/holes_characterization.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -34,8 +34,9 @@ class LittleTest(McarTest):
3434
imputer : Optional[ImputerEM]
3535
Imputer based on the EM algorithm. The 'model' attribute must be equal to 'multinormal'.
3636
If None, the default ImputerEM is taken.
37-
random_state : Union[None, int, np.random.RandomState], optional
38-
Controls the randomness of the fit_transform, by default None
37+
random_state : int, RandomState instance or None, default=None
38+
Controls the randomness.
39+
Pass an int for reproducible output across multiple function calls.
3940
"""
4041

4142
def __init__(

qolmat/benchmark/missing_patterns.py

Lines changed: 21 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -63,8 +63,9 @@ class _HoleGenerator:
6363
Names of the columns for which holes must be created, by default None
6464
ratio_masked : Optional[float]
6565
Ratio of values ​​to mask, by default 0.05.
66-
random_state : Optional[int]
67-
The seed used by the random number generator, by default 42.
66+
random_state : int, RandomState instance or None, default=None
67+
Controls the randomness.
68+
Pass an int for reproducible output across multiple function calls.
6869
groups: Tuple[str, ...]
6970
Column names used to group the data
7071
"""
@@ -150,8 +151,9 @@ class UniformHoleGenerator(_HoleGenerator):
150151
Names of the columns for which holes must be created, by default None
151152
ratio_masked : Optional[float], optional
152153
Ratio of masked values ​​to add, by default 0.05.
153-
random_state : Optional[int], optional
154-
The seed used by the random number generator, by default 42.
154+
random_state : int, RandomState instance or None, default=None
155+
Controls the randomness.
156+
Pass an int for reproducible output across multiple function calls.
155157
sample_proportional: bool, optional
156158
If True, generates holes in target columns with same equal frequency.
157159
If False, reproduces the empirical proportions between the variables.
@@ -215,8 +217,9 @@ class _SamplerHoleGenerator(_HoleGenerator):
215217
Names of the columns for which holes must be created, by default None
216218
ratio_masked : Optional[float], optional
217219
Ratio of masked values ​​to add, by default 0.05.
218-
random_state : Optional[int], optional
219-
The seed used by the random number generator, by default 42.
220+
random_state : int, RandomState instance or None, default=None
221+
Controls the randomness.
222+
Pass an int for reproducible output across multiple function calls.
220223
groups: Tuple[str, ...]
221224
Column names used to group the data
222225
"""
@@ -321,8 +324,9 @@ class GeometricHoleGenerator(_SamplerHoleGenerator):
321324
Names of the columns for which holes must be created, by default None
322325
ratio_masked : Optional[float], optional
323326
Ratio of masked values ​​to add, by default 0.05.
324-
random_state : Union[None, int, np.random.RandomState], optional
325-
The seed used by the random number generator, by default 42.
327+
random_state : int, RandomState instance or None, default=None
328+
Controls the randomness.
329+
Pass an int for reproducible output across multiple function calls.
326330
groups: Tuple[str, ...]
327331
Column names used to group the data
328332
"""
@@ -390,8 +394,9 @@ class EmpiricalHoleGenerator(_SamplerHoleGenerator):
390394
Names of the columns for which holes must be created, by default None
391395
ratio_masked : Optional[float], optional
392396
Ratio of masked values ​​to add, by default 0.05.
393-
random_state : Optional[int], optional
394-
The seed used by the random number generator, by default 42.
397+
random_state : int, RandomState instance or None, default=None
398+
Controls the randomness.
399+
Pass an int for reproducible output across multiple function calls.
395400
groups: Tuple[str, ...]
396401
Column names used to group the data
397402
"""
@@ -485,8 +490,9 @@ class MultiMarkovHoleGenerator(_HoleGenerator):
485490
Names of the columns for which holes must be created, by default None
486491
ratio_masked : Optional[float], optional
487492
Ratio of masked values to add, by default 0.05
488-
random_state : Optional[int], optional
489-
The seed used by the random number generator, by default 42.
493+
random_state : int, RandomState instance or None, default=None
494+
Controls the randomness.
495+
Pass an int for reproducible output across multiple function calls.
490496
groups: Tuple[str, ...]
491497
Column names used to group the data
492498
"""
@@ -634,8 +640,9 @@ class GroupedHoleGenerator(_HoleGenerator):
634640
Names of the columns for which holes must be created, by default None
635641
ratio_masked : Optional[float], optional
636642
Ratio of masked to add, by default 0.05
637-
random_state : Optional[int], optional
638-
The seed used by the random number generator, by default 42.
643+
random_state : int, RandomState instance or None, default=None
644+
Controls the randomness.
645+
Pass an int for reproducible output across multiple function calls.
639646
groups : Tuple[str, ...]
640647
Names of the columns forming the groups, by default []
641648
"""

qolmat/imputations/diffusions/ddpms.py

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,17 @@
1-
from typing import Dict, List, Callable, Tuple
1+
from typing import Dict, List, Callable, Tuple, Union
22
from typing_extensions import Self
3-
import math
3+
import sys
44
import numpy as np
55
import pandas as pd
66
import time
77
from datetime import timedelta
88
from tqdm import tqdm
9-
import gc
109

1110
import torch
1211
from torch.utils.data import DataLoader, TensorDataset
1312
from sklearn import preprocessing
13+
from sklearn import utils as sku
14+
1415

1516
from qolmat.imputations.diffusions.base import AutoEncoder, ResidualBlock, ResidualBlockTS
1617
from qolmat.imputations.diffusions.utils import get_num_params
@@ -39,6 +40,7 @@ def __init__(
3940
p_dropout: float = 0.0,
4041
num_sampling: int = 1,
4142
is_clip: bool = True,
43+
random_state: Union[None, int, np.random.RandomState] = None,
4244
):
4345
"""Diffusion model for tabular data based on
4446
Denoising Diffusion Probabilistic Models (DDPM) of
@@ -68,6 +70,9 @@ def __init__(
6870
Dropout probability, by default 0.0
6971
num_sampling : int, optional
7072
Number of samples generated for each cell, by default 1
73+
random_state : int, RandomState instance or None, default=None
74+
Controls the randomness.
75+
Pass an int for reproducible output across multiple function calls.
7176
"""
7277
self.device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
7378

@@ -108,6 +113,9 @@ def __init__(
108113
self.is_clip = is_clip
109114

110115
self.normalizer_x = preprocessing.StandardScaler()
116+
self.random_state = sku.check_random_state(random_state)
117+
seed_torch = self.random_state.randint(2**31 - 1)
118+
torch.manual_seed(seed_torch)
111119

112120
def _q_sample(self, x: torch.Tensor, t: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
113121
"""Section 3.2, algorithm 1 formula implementation. Forward process, defined by `q`.
@@ -345,7 +353,6 @@ def fit(
345353
round: int = 10,
346354
cols_imputed: Tuple[str, ...] = (),
347355
) -> Self:
348-
349356
"""Fit data
350357
351358
Parameters
@@ -537,6 +544,7 @@ def __init__(
537544
p_dropout: float = 0.0,
538545
num_sampling: int = 1,
539546
is_rolling: bool = False,
547+
random_state: Union[None, int, np.random.RandomState] = None,
540548
):
541549
"""Diffusion model for time-series data based on the works of
542550
Ho et al., 2020 (https://arxiv.org/abs/2006.11239),
@@ -575,6 +583,9 @@ def __init__(
575583
Number of samples generated for each cell, by default 1
576584
is_rolling : bool, optional
577585
Use pandas.DataFrame.rolling for preprocessing data, by default False
586+
random_state : int, RandomState instance or None, default=None
587+
Controls the randomness.
588+
Pass an int for reproducible output across multiple function calls.
578589
"""
579590
super().__init__(
580591
num_noise_steps,
@@ -586,6 +597,7 @@ def __init__(
586597
num_blocks,
587598
p_dropout,
588599
num_sampling,
600+
random_state=random_state,
589601
)
590602

591603
self.dim_feedforward = dim_feedforward

qolmat/imputations/imputers_pytorch.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -568,6 +568,17 @@ def __init__(
568568
freq_str : str
569569
Frequency string of DateOffset of Pandas.
570570
It is for processing time-series data, used in diffusion models e.g., TsDDPM.
571+
572+
Examples
573+
--------
574+
>>> import numpy as np
575+
>>> from qolmat.imputations.imputers_pytorch import ImputerDiffusion
576+
>>> from qolmat.imputations.diffusions.ddpms import TabDDPM
577+
>>>
578+
>>> X = np.array([[1, 1, 1, 1], [np.nan, np.nan, 3, 2], [1, 2, 2, 1], [2, 2, 2, 2]])
579+
>>> imputer = ImputerDiffusion(model=TabDDPM(random_state=11), epochs=50, batch_size=1)
580+
>>>
581+
>>> df_imputed = imputer.fit_transform(X)
571582
"""
572583
super().__init__(groups=groups, columnwise=False)
573584
self.model = model

0 commit comments

Comments
 (0)