Skip to content

Commit a901097

Browse files
Julien RousselJulien Roussel
authored andcommitted
random state fixed for DDPM
1 parent 210e2f4 commit a901097

File tree

3 files changed

+38
-27
lines changed

3 files changed

+38
-27
lines changed

qolmat/analysis/holes_characterization.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -34,8 +34,9 @@ class LittleTest(McarTest):
3434
imputer : Optional[ImputerEM]
3535
Imputer based on the EM algorithm. The 'model' attribute must be equal to 'multinormal'.
3636
If None, the default ImputerEM is taken.
37-
random_state : Union[None, int, np.random.RandomState], optional
38-
Controls the randomness of the fit_transform, by default None
37+
random_state : int, RandomState instance or None, default=None
38+
Controls the randomness.
39+
Pass an int for reproducible output across multiple function calls.
3940
"""
4041

4142
def __init__(

qolmat/benchmark/missing_patterns.py

Lines changed: 21 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -63,8 +63,9 @@ class _HoleGenerator:
6363
Names of the columns for which holes must be created, by default None
6464
ratio_masked : Optional[float]
6565
Ratio of values ​​to mask, by default 0.05.
66-
random_state : Optional[int]
67-
The seed used by the random number generator, by default 42.
66+
random_state : int, RandomState instance or None, default=None
67+
Controls the randomness.
68+
Pass an int for reproducible output across multiple function calls.
6869
groups: Tuple[str, ...]
6970
Column names used to group the data
7071
"""
@@ -150,8 +151,9 @@ class UniformHoleGenerator(_HoleGenerator):
150151
Names of the columns for which holes must be created, by default None
151152
ratio_masked : Optional[float], optional
152153
Ratio of masked values ​​to add, by default 0.05.
153-
random_state : Optional[int], optional
154-
The seed used by the random number generator, by default 42.
154+
random_state : int, RandomState instance or None, default=None
155+
Controls the randomness.
156+
Pass an int for reproducible output across multiple function calls.
155157
sample_proportional: bool, optional
156158
If True, generates holes in target columns with same equal frequency.
157159
If False, reproduces the empirical proportions between the variables.
@@ -215,8 +217,9 @@ class _SamplerHoleGenerator(_HoleGenerator):
215217
Names of the columns for which holes must be created, by default None
216218
ratio_masked : Optional[float], optional
217219
Ratio of masked values ​​to add, by default 0.05.
218-
random_state : Optional[int], optional
219-
The seed used by the random number generator, by default 42.
220+
random_state : int, RandomState instance or None, default=None
221+
Controls the randomness.
222+
Pass an int for reproducible output across multiple function calls.
220223
groups: Tuple[str, ...]
221224
Column names used to group the data
222225
"""
@@ -321,8 +324,9 @@ class GeometricHoleGenerator(_SamplerHoleGenerator):
321324
Names of the columns for which holes must be created, by default None
322325
ratio_masked : Optional[float], optional
323326
Ratio of masked values ​​to add, by default 0.05.
324-
random_state : Union[None, int, np.random.RandomState], optional
325-
The seed used by the random number generator, by default 42.
327+
random_state : int, RandomState instance or None, default=None
328+
Controls the randomness.
329+
Pass an int for reproducible output across multiple function calls.
326330
groups: Tuple[str, ...]
327331
Column names used to group the data
328332
"""
@@ -390,8 +394,9 @@ class EmpiricalHoleGenerator(_SamplerHoleGenerator):
390394
Names of the columns for which holes must be created, by default None
391395
ratio_masked : Optional[float], optional
392396
Ratio of masked values ​​to add, by default 0.05.
393-
random_state : Optional[int], optional
394-
The seed used by the random number generator, by default 42.
397+
random_state : int, RandomState instance or None, default=None
398+
Controls the randomness.
399+
Pass an int for reproducible output across multiple function calls.
395400
groups: Tuple[str, ...]
396401
Column names used to group the data
397402
"""
@@ -485,8 +490,9 @@ class MultiMarkovHoleGenerator(_HoleGenerator):
485490
Names of the columns for which holes must be created, by default None
486491
ratio_masked : Optional[float], optional
487492
Ratio of masked values to add, by default 0.05
488-
random_state : Optional[int], optional
489-
The seed used by the random number generator, by default 42.
493+
random_state : int, RandomState instance or None, default=None
494+
Controls the randomness.
495+
Pass an int for reproducible output across multiple function calls.
490496
groups: Tuple[str, ...]
491497
Column names used to group the data
492498
"""
@@ -634,8 +640,9 @@ class GroupedHoleGenerator(_HoleGenerator):
634640
Names of the columns for which holes must be created, by default None
635641
ratio_masked : Optional[float], optional
636642
Ratio of masked to add, by default 0.05
637-
random_state : Optional[int], optional
638-
The seed used by the random number generator, by default 42.
643+
random_state : int, RandomState instance or None, default=None
644+
Controls the randomness.
645+
Pass an int for reproducible output across multiple function calls.
639646
groups : Tuple[str, ...]
640647
Names of the columns forming the groups, by default []
641648
"""

qolmat/imputations/diffusions/ddpms.py

Lines changed: 14 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,17 @@
11
from typing import Dict, List, Callable, Tuple, Union
22
from typing_extensions import Self
3-
import math
3+
import sys
44
import numpy as np
55
import pandas as pd
66
import time
77
from datetime import timedelta
88
from tqdm import tqdm
9-
import gc
109

1110
import torch
1211
from torch.utils.data import DataLoader, TensorDataset
1312
from sklearn import preprocessing
13+
from sklearn import utils as sku
14+
1415

1516
from qolmat.imputations.diffusions.base import AutoEncoder, ResidualBlock, ResidualBlockTS
1617
from qolmat.imputations.diffusions.utils import get_num_params
@@ -39,7 +40,7 @@ def __init__(
3940
p_dropout: float = 0.0,
4041
num_sampling: int = 1,
4142
is_clip: bool = True,
42-
random_state: Union[None, int] = None,
43+
random_state: Union[None, int, np.random.RandomState] = None,
4344
):
4445
"""Diffusion model for tabular data based on
4546
Denoising Diffusion Probabilistic Models (DDPM) of
@@ -69,8 +70,9 @@ def __init__(
6970
Dropout probability, by default 0.0
7071
num_sampling : int, optional
7172
Number of samples generated for each cell, by default 1
72-
random_state : int, optional
73-
The seed of the pseudo random number generator to use, for reproductibility.
73+
random_state : int, RandomState instance or None, default=None
74+
Controls the randomness.
75+
Pass an int for reproducible output across multiple function calls.
7476
"""
7577
self.device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
7678

@@ -111,8 +113,9 @@ def __init__(
111113
self.is_clip = is_clip
112114

113115
self.normalizer_x = preprocessing.StandardScaler()
114-
if random_state is not None:
115-
torch.manual_seed(random_state)
116+
self.random_state = sku.check_random_state(random_state)
117+
seed_torch = self.random_state.randint(sys.maxsize)
118+
torch.manual_seed(seed_torch)
116119

117120
def _q_sample(self, x: torch.Tensor, t: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
118121
"""Section 3.2, algorithm 1 formula implementation. Forward process, defined by `q`.
@@ -350,7 +353,6 @@ def fit(
350353
round: int = 10,
351354
cols_imputed: Tuple[str, ...] = (),
352355
) -> Self:
353-
354356
"""Fit data
355357
356358
Parameters
@@ -542,7 +544,7 @@ def __init__(
542544
p_dropout: float = 0.0,
543545
num_sampling: int = 1,
544546
is_rolling: bool = False,
545-
random_state: Union[None, int] = None,
547+
random_state: Union[None, int, np.random.RandomState] = None,
546548
):
547549
"""Diffusion model for time-series data based on the works of
548550
Ho et al., 2020 (https://arxiv.org/abs/2006.11239),
@@ -581,8 +583,9 @@ def __init__(
581583
Number of samples generated for each cell, by default 1
582584
is_rolling : bool, optional
583585
Use pandas.DataFrame.rolling for preprocessing data, by default False
584-
random_state : int, optional
585-
The seed of the pseudo random number generator to use, for reproductibility.
586+
random_state : int, RandomState instance or None, default=None
587+
Controls the randomness.
588+
Pass an int for reproducible output across multiple function calls.
586589
"""
587590
super().__init__(
588591
num_noise_steps,

0 commit comments

Comments
 (0)