Skip to content

Commit 0ab6ba1

Browse files
Julien RousselJulien Roussel
authored andcommitted
docstring test patched
1 parent 687917c commit 0ab6ba1

File tree

2 files changed

+27
-5
lines changed

2 files changed

+27
-5
lines changed

examples/benchmark.md

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -305,6 +305,32 @@ except ModuleNotFoundError:
305305
For the example, we use a simple MLP model with 3 layers of neurons.
306306
Then we train the model without taking a group on the stations
307307

308+
```python
309+
import numpy as np
310+
from qolmat.imputations.imputers_pytorch import ImputerDiffusion
311+
from qolmat.imputations.diffusions.ddpms import TabDDPM
312+
313+
X = np.array([[1, 1, 1, 1], [np.nan, np.nan, 3, 2], [1, 2, 2, 1], [2, 2, 2, 2]])
314+
imputer = ImputerDiffusion(model=TabDDPM(random_state=11), epochs=50, batch_size=1)
315+
316+
imputer.fit_transform(X)
317+
```
318+
319+
```python
320+
import numpy as np
321+
from qolmat.imputations.imputers_pytorch import ImputerDiffusion
322+
from qolmat.imputations.diffusions.ddpms import TabDDPM
323+
324+
X = np.array([[1, 1, 1, 1], [np.nan, np.nan, 3, 2], [1, 2, 2, 1], [2, 2, 2, 2]])
325+
imputer = ImputerDiffusion(model=TabDDPM(random_state=11), epochs=50, batch_size=1)
326+
327+
imputer.fit_transform(X)
328+
```
329+
330+
```python
331+
1.33573675, 1.40472937
332+
```
333+
308334
```python
309335
fig = plt.figure(figsize=(10 * n_stations, 3 * n_cols))
310336
for i_station, (station, df) in enumerate(df_data.groupby("station")):

qolmat/imputations/imputers_pytorch.py

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -578,11 +578,7 @@ def __init__(
578578
>>> X = np.array([[1, 1, 1, 1], [np.nan, np.nan, 3, 2], [1, 2, 2, 1], [2, 2, 2, 2]])
579579
>>> imputer = ImputerDiffusion(model=TabDDPM(random_state=11), epochs=50, batch_size=1)
580580
>>>
581-
>>> imputer.fit_transform(X)
582-
array([[1. , 1. , 1. , 1. ],
583-
[1.33573651, 1.40472949, 3. , 2. ],
584-
[1. , 2. , 2. , 1. ],
585-
[2. , 2. , 2. , 2. ]])
581+
>>> df_imputed = imputer.fit_transform(X)
586582
"""
587583
super().__init__(groups=groups, columnwise=False)
588584
self.model = model

0 commit comments

Comments
 (0)