2121import matplotlib .pyplot as plt
2222import numpy as np
2323import pandas as pd
24+ from sklearn import utils as sku
25+ from torch import rand
2426
2527from qolmat .benchmark import missing_patterns
2628from qolmat .utils import data
2729
30+ seed = 1234
31+ rng = sku .check_random_state (seed )
32+
2833# %%
2934# 1. Data
3035# ---------------------------------------------------------------
4247columns = ["TEMP" , "PRES" , "DEWP" , "RAIN" , "WSPM" ]
4348df_data = df_data [columns ]
4449
45- df = data .add_holes (df_data , ratio_masked = 0.2 , mean_size = 120 )
50+ df = data .add_holes (df_data , ratio_masked = 0.2 , mean_size = 120 , random_state = rng )
4651cols_to_impute = df .columns
4752
4853# %%
@@ -169,8 +174,8 @@ def plot_cdf(
169174 axs [ind ].plot (sorted_data , cdf , c = "gray" , lw = 2 , label = "original" )
170175
171176 for df_mask , label , color in zip (list_df_mask , labels , colors ):
172- array_mask = df_mask .copy ()
173- array_mask [array_mask == True ] = np .nan
177+ array_mask = df_mask .astype ( float ). copy ()
178+ array_mask [df_mask ] = np .nan
174179 hole_sizes_created = get_holes_sizes_column_wise (array_mask .to_numpy ())
175180
176181 for ind , (hole_created , col ) in enumerate (
@@ -197,7 +202,7 @@ def plot_cdf(
197202# Note this class is more suited for tabular datasets.
198203
199204uniform_generator = missing_patterns .UniformHoleGenerator (
200- n_splits = 1 , subset = df .columns , ratio_masked = 0.1
205+ n_splits = 1 , subset = df .columns , ratio_masked = 0.1 , random_state = rng
201206)
202207uniform_mask = uniform_generator .split (df )[0 ]
203208
@@ -223,7 +228,7 @@ def plot_cdf(
223228# :class:`~qolmat.benchmark.missing_patterns.UniformHoleGenerator` class.
224229
225230geometric_generator = missing_patterns .GeometricHoleGenerator (
226- n_splits = 1 , subset = cols_to_impute , ratio_masked = 0.1
231+ n_splits = 1 , subset = cols_to_impute , ratio_masked = 0.1 , random_state = rng
227232)
228233geometric_mask = geometric_generator .split (df )[0 ]
229234
@@ -249,7 +254,7 @@ def plot_cdf(
249254# is learned on each group: here on each station.
250255
251256empirical_generator = missing_patterns .EmpiricalHoleGenerator (
252- n_splits = 1 , subset = df .columns , ratio_masked = 0.1 , groups = ("station" ,)
257+ n_splits = 1 , subset = df .columns , ratio_masked = 0.1 , groups = ("station" ,), random_state = rng
253258)
254259empirical_mask = empirical_generator .split (df )[0 ]
255260
@@ -274,7 +279,7 @@ def plot_cdf(
274279# :class:`~qolmat.benchmark.missing_patterns.MultiMarkovHoleGenerator` class.
275280
276281multi_markov_generator = missing_patterns .MultiMarkovHoleGenerator (
277- n_splits = 1 , subset = df .columns , ratio_masked = 0.1
282+ n_splits = 1 , subset = df .columns , ratio_masked = 0.1 , random_state = rng
278283)
279284multi_markov_mask = multi_markov_generator .split (df )[0 ]
280285
@@ -297,7 +302,7 @@ def plot_cdf(
297302# :class:`~qolmat.benchmark.missing_patterns.GroupedHoleGenerator` class.
298303
299304grouped_generator = missing_patterns .GroupedHoleGenerator (
300- n_splits = 1 , subset = df .columns , ratio_masked = 0.1 , groups = ("station" ,)
305+ n_splits = 1 , subset = df .columns , ratio_masked = 0.1 , groups = ("station" ,), random_state = rng
301306)
302307grouped_mask = grouped_generator .split (df )[0 ]
303308
0 commit comments