Skip to content

Commit 15510c6

Browse files
Julien RousselJulien Roussel
authored andcommitted
merged
2 parents ef683c8 + e84b4e8 commit 15510c6

File tree

23 files changed

+1030
-517
lines changed

23 files changed

+1030
-517
lines changed

.github/workflows/test.yml

Lines changed: 3 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ jobs:
1717
strategy:
1818
matrix:
1919
os: [ubuntu-latest, windows-latest]
20-
python-version: ["3.8", "3.9", "3.10", "3.11"]
20+
python-version: ["3.9", "3.11", "3.12"]
2121
defaults:
2222
run:
2323
shell: bash -l {0}
@@ -28,10 +28,10 @@ jobs:
2828
run: |
2929
if [[ "${GITHUB_REF}" == "refs/heads/main" || "${GITHUB_REF}" == "refs/heads/dev" ]]; then
3030
echo "os-matrix=ubuntu-latest,windows-latest" >> $GITHUB_ENV
31-
echo "python-matrix=3.8,3.9,3.10,3.11" >> $GITHUB_ENV
31+
echo "python-matrix=3.9,3.11,3.12" >> $GITHUB_ENV
3232
else
3333
echo "os-matrix=ubuntu-latest" >> $GITHUB_ENV
34-
echo "python-matrix=3.11" >> $GITHUB_ENV
34+
echo "python-matrix=3.12" >> $GITHUB_ENV
3535
fi
3636
- name: Checkout
3737
uses: actions/checkout@v3
@@ -62,35 +62,6 @@ jobs:
6262
uses: codecov/codecov-action@v3
6363
env:
6464
CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }}
65-
66-
docs:
67-
runs-on: ubuntu-latest
68-
needs: check
69-
70-
steps:
71-
- name: Checkout
72-
uses: actions/checkout@v3
73-
- name: Python
74-
uses: actions/setup-python@v4
75-
with:
76-
python-version: ${{ matrix.python-version }}
77-
- name: Cache Poetry
78-
uses: actions/cache@v3
79-
with:
80-
path: |
81-
~/.cache/pypoetry
82-
~/.cache/pip
83-
key: ${{ runner.os }}-poetry-${{ matrix.python-version }}-${{ hashFiles('**/poetry.lock') }}
84-
restore-keys: |
85-
${{ runner.os }}-poetry-${{ matrix.python-version }}-
86-
- name: Poetry
87-
uses: snok/install-poetry@v1
88-
with:
89-
version: 1.8.3
90-
- name: Lock
91-
run: poetry lock --no-update
92-
- name: Install
93-
run: poetry install
9465
- name: Check Changed Files
9566
id: changed-files
9667
run: |

.python-version

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
env_qolmat_3.9

docs/index.rst

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -22,14 +22,14 @@
2222
.. toctree::
2323
:maxdepth: 2
2424
:hidden:
25-
:caption: API
25+
:caption: ANALYSIS
2626

27-
api
27+
analysis
28+
examples/tutorials/plot_tuto_mcar
2829

2930
.. toctree::
3031
:maxdepth: 2
3132
:hidden:
32-
:caption: ANALYSIS
33+
:caption: API
3334

34-
analysis
35-
examples/tutorials/plot_tuto_mcar
35+
api

examples/benchmark.md

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -311,7 +311,7 @@ from qolmat.imputations.imputers_pytorch import ImputerDiffusion
311311
from qolmat.imputations.diffusions.ddpms import TabDDPM
312312

313313
X = np.array([[1, 1, 1, 1], [np.nan, np.nan, 3, 2], [1, 2, 2, 1], [2, 2, 2, 2]])
314-
imputer = ImputerDiffusion(model=TabDDPM(random_state=11), epochs=50, batch_size=1)
314+
imputer = ImputerDiffusion(epochs=50, batch_size=1, random_state=11)
315315

316316
imputer.fit_transform(X)
317317
```
@@ -322,7 +322,7 @@ from qolmat.imputations.imputers_pytorch import ImputerDiffusion
322322
from qolmat.imputations.diffusions.ddpms import TabDDPM
323323

324324
X = np.array([[1, 1, 1, 1], [np.nan, np.nan, 3, 2], [1, 2, 2, 1], [2, 2, 2, 2]])
325-
imputer = ImputerDiffusion(model=TabDDPM(random_state=11), epochs=50, batch_size=1)
325+
imputer = ImputerDiffusion(epochs=50, batch_size=1, random_state=11)
326326

327327
imputer.fit_transform(X)
328328
```
@@ -358,7 +358,7 @@ encoder, decoder = imputers_pytorch.build_autoencoder(input_dim=n_variables,lat
358358
```python
359359
dict_imputers["MLP"] = imputer_mlp = imputers_pytorch.ImputerRegressorPyTorch(estimator=estimator, groups=('station',), epochs=500)
360360
dict_imputers["Autoencoder"] = imputer_autoencoder = imputers_pytorch.ImputerAutoencoder(encoder, decoder, max_iterations=100, epochs=100)
361-
dict_imputers["Diffusion"] = imputer_diffusion = imputers_pytorch.ImputerDiffusion(model=TabDDPM(num_sampling=5), epochs=100, batch_size=100)
361+
dict_imputers["Diffusion"] = imputer_diffusion = imputers_pytorch.ImputerDiffusion(epochs=100, batch_size=100, num_sampling=5)
362362
```
363363

364364
We can re-run the imputation model benchmark as before.

examples/tutorials/plot_tuto_diffusion_models.py

Lines changed: 7 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,6 @@
7171
df_data_valid = df_data.iloc[:500]
7272

7373
tabddpm = ImputerDiffusion(
74-
model=TabDDPM(),
7574
epochs=10,
7675
batch_size=100,
7776
x_valid=df_data_valid,
@@ -160,12 +159,8 @@
160159
# reconstruction errors (mae) but increases distribution distance (kl_columnwise).
161160

162161
dict_imputers = {
163-
"num_sampling=5": ImputerDiffusion(
164-
model=TabDDPM(num_sampling=5), epochs=10, batch_size=100
165-
),
166-
"num_sampling=10": ImputerDiffusion(
167-
model=TabDDPM(num_sampling=10), epochs=10, batch_size=100
168-
),
162+
"num_sampling=5": ImputerDiffusion(epochs=10, batch_size=100, num_sampling=5),
163+
"num_sampling=10": ImputerDiffusion(epochs=10, batch_size=100, num_sampling=10),
169164
}
170165

171166
comparison = comparator.Comparator(
@@ -187,7 +182,7 @@
187182
#
188183
# Two important hyperparameters for processing time-series data are ``index_datetime``
189184
# and ``freq_str``.
190-
# E.g., ``ImputerDiffusion(model=TabDDPM(), index_datetime='datetime', freq_str='1D')``,
185+
# E.g., ``ImputerDiffusion(index_datetime='datetime', freq_str='1D')``,
191186
#
192187
# * ``index_datetime``: the column name of datetime in index. It must be a pandas datetime object.
193188
#
@@ -210,15 +205,16 @@
210205
# but requires a longer training/inference time.
211206

212207
dict_imputers = {
213-
"tabddpm": ImputerDiffusion(
214-
model=TabDDPM(num_sampling=5), epochs=10, batch_size=100
208+
"tabddpm": ImputerDiffusion(model="TabDDPM", epochs=10, batch_size=100, num_sampling=5
215209
),
216210
"tsddpm": ImputerDiffusion(
217-
model=TsDDPM(num_sampling=5, is_rolling=False),
211+
model="TsDDPM",
218212
epochs=10,
219213
batch_size=5,
220214
index_datetime="date",
221215
freq_str="5D",
216+
num_sampling=5,
217+
is_rolling=False
222218
),
223219
}
224220

pyproject.toml

Lines changed: 29 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -34,30 +34,32 @@ classifiers = [
3434
# DEPENDENCIES
3535

3636
[tool.poetry.dependencies]
37-
python = ">=3.8.1,<3.12"
37+
python = ">=3.9,<3.13"
38+
hyperopt = "*"
39+
numpy = ">= 1.24"
40+
pandas = ">= 2.0.1"
41+
scipy = "*"
42+
scikit-learn = ">= 1.6"
43+
sphinx-markdown-tables = { version = "*", optional = true }
44+
statsmodels = ">= 0.14.0"
45+
typed-ast = { version = "*", optional = true }
46+
category-encoders = "^2.6.3"
47+
dcor = ">= 0.6"
48+
49+
[tool.poetry.group.torch.dependencies]
50+
torch = "< 2.5"
51+
52+
[tool.poetry.group.dev.dependencies]
3853
bump2version = "1.0.1"
54+
ipykernel = "^6.29.5"
3955
jupyter = "1.0.0"
4056
jupyterlab = "1.2.6"
4157
jupytext = "1.14.4"
42-
hyperopt = "0.2.7"
43-
numpy = "1.24.4"
58+
matplotlib = "*"
4459
packaging = "23.1"
45-
pandas = "2.0.1"
46-
scipy = "1.10.1"
47-
scikit-learn = "1.3.2"
48-
sphinx-markdown-tables = { version = "*", optional = true }
49-
statsmodels = "0.14.0"
50-
typed-ast = { version = "*", optional = true }
60+
pre-commit = "2.21.0"
5161
twine = "3.7.1"
5262
wheel = "0.37.1"
53-
category-encoders = "^2.6.3"
54-
ipykernel = "^6.29.5"
55-
torch = "*"
56-
dcor = "0.6"
57-
58-
[tool.poetry.group.dev.dependencies]
59-
matplotlib = "3.6.2"
60-
pre-commit = "2.21.0"
6163

6264
[tool.poetry.group.checkers.dependencies]
6365
bandit = "^1.7.9"
@@ -72,7 +74,7 @@ codecov = "^2.1.13"
7274

7375
[tool.poetry.group.docs.dependencies]
7476
numpydoc = "1.1.0"
75-
sphinx = "4.3.2"
77+
sphinx = ">= 5.0"
7678
sphinx-gallery = "0.10.1"
7779
sphinx_rtd_theme = "1.0.0"
7880
sphinx_markdown_tables = "0.0.17"
@@ -141,7 +143,15 @@ docstring-code-format = true
141143

142144
[tool.ruff.lint]
143145
select = ["C", "D", "E", "F", "I", "Q", "W"]
144-
ignore = ["C901", "D107", "D203", "D213"]
146+
ignore = [
147+
"C901",
148+
"D107",
149+
"D203",
150+
"D213",
151+
"N803", # allow X as a name for data
152+
"N806", # allow X as a name for data
153+
"N816", # allow mixed case names such as np_X_t as a name for data
154+
]
145155

146156
[tool.ruff.lint.isort]
147157
known-first-party = ["qolmat"]

0 commit comments

Comments
 (0)