Skip to content

Commit 4dcd07d

Browse files
committed
propagate random_state param in TsDDPM
1 parent 4156b87 commit 4dcd07d

File tree

1 file changed

+6
-0
lines changed

1 file changed

+6
-0
lines changed

qolmat/imputations/diffusions/ddpms.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,8 @@ def __init__(
6969
Dropout probability, by default 0.0
7070
num_sampling : int, optional
7171
Number of samples generated for each cell, by default 1
72+
random_state : int, optional
73+
The seed of the pseudo random number generator to use, for reproductibility.
7274
"""
7375
self.device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
7476

@@ -540,6 +542,7 @@ def __init__(
540542
p_dropout: float = 0.0,
541543
num_sampling: int = 1,
542544
is_rolling: bool = False,
545+
random_state: Union[None, int] = None,
543546
):
544547
"""Diffusion model for time-series data based on the works of
545548
Ho et al., 2020 (https://arxiv.org/abs/2006.11239),
@@ -578,6 +581,8 @@ def __init__(
578581
Number of samples generated for each cell, by default 1
579582
is_rolling : bool, optional
580583
Use pandas.DataFrame.rolling for preprocessing data, by default False
584+
random_state : int, optional
585+
The seed of the pseudo random number generator to use, for reproductibility.
581586
"""
582587
super().__init__(
583588
num_noise_steps,
@@ -589,6 +594,7 @@ def __init__(
589594
num_blocks,
590595
p_dropout,
591596
num_sampling,
597+
random_state=random_state,
592598
)
593599

594600
self.dim_feedforward = dim_feedforward

0 commit comments

Comments
 (0)