@@ -69,6 +69,8 @@ def __init__(
6969 Dropout probability, by default 0.0
7070 num_sampling : int, optional
7171 Number of samples generated for each cell, by default 1
72+ random_state : int, optional
73+ The seed of the pseudo random number generator to use, for reproductibility.
7274 """
7375 self .device = torch .device ("cuda" ) if torch .cuda .is_available () else torch .device ("cpu" )
7476
@@ -540,6 +542,7 @@ def __init__(
540542 p_dropout : float = 0.0 ,
541543 num_sampling : int = 1 ,
542544 is_rolling : bool = False ,
545+ random_state : Union [None , int ] = None ,
543546 ):
544547 """Diffusion model for time-series data based on the works of
545548 Ho et al., 2020 (https://arxiv.org/abs/2006.11239),
@@ -578,6 +581,8 @@ def __init__(
578581 Number of samples generated for each cell, by default 1
579582 is_rolling : bool, optional
580583 Use pandas.DataFrame.rolling for preprocessing data, by default False
584+ random_state : int, optional
585+ The seed of the pseudo random number generator to use, for reproductibility.
581586 """
582587 super ().__init__ (
583588 num_noise_steps ,
@@ -589,6 +594,7 @@ def __init__(
589594 num_blocks ,
590595 p_dropout ,
591596 num_sampling ,
597+ random_state = random_state ,
592598 )
593599
594600 self .dim_feedforward = dim_feedforward
0 commit comments