Skip to content

Commit ba255b3

Browse files
authored
cleanup (#2)
* cleanup * update workflows * precommit * fixes
1 parent ea54572 commit ba255b3

File tree

10 files changed

+130
-90
lines changed

10 files changed

+130
-90
lines changed

.github/workflows/release.yml

Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,73 @@
1+
name: Package Release
2+
3+
on:
4+
release:
5+
types: [created]
6+
7+
8+
jobs:
9+
deploy_osx:
10+
runs-on: ${{ matrix.os }}
11+
strategy:
12+
matrix:
13+
python-version: ["3.7", "3.8", "3.9", "3.10"]
14+
os: [macos-latest]
15+
16+
steps:
17+
- uses: actions/checkout@v2
18+
with:
19+
submodules: true
20+
- name: Set up Python
21+
uses: actions/setup-python@v1
22+
with:
23+
python-version: ${{ matrix.python-version }}
24+
- name: Build and publish
25+
env:
26+
TWINE_USERNAME: ${{ secrets.PYPI_USERNAME }}
27+
TWINE_PASSWORD: ${{ secrets.PYPI_PASSWORD }}
28+
run: ${GITHUB_WORKSPACE}/.github/workflows/scripts/release_osx.sh
29+
30+
deploy_linux:
31+
strategy:
32+
matrix:
33+
python-version:
34+
- cp37-cp37m
35+
- cp38-cp38
36+
- cp39-cp39
37+
- cp310-cp310
38+
39+
runs-on: ubuntu-latest
40+
container: quay.io/pypa/manylinux2014_x86_64
41+
steps:
42+
- uses: actions/checkout@v1
43+
with:
44+
submodules: true
45+
- name: Set target Python version PATH
46+
run: |
47+
echo "/opt/python/${{ matrix.python-version }}/bin" >> $GITHUB_PATH
48+
- name: Build and publish
49+
env:
50+
TWINE_USERNAME: ${{ secrets.PYPI_USERNAME }}
51+
TWINE_PASSWORD: ${{ secrets.PYPI_PASSWORD }}
52+
run: ${GITHUB_WORKSPACE}/.github/workflows/scripts/release_linux.sh
53+
54+
deploy_windows:
55+
runs-on: windows-latest
56+
strategy:
57+
matrix:
58+
python-version: ["3.7", "3.8", "3.9", "3.10"]
59+
60+
steps:
61+
- uses: actions/checkout@v2
62+
with:
63+
submodules: true
64+
- name: Set up Python ${{ matrix.python-version }}
65+
uses: actions/setup-python@v1
66+
with:
67+
python-version: ${{ matrix.python-version }}
68+
- name: Build and publish
69+
env:
70+
TWINE_USERNAME: ${{ secrets.PYPI_USERNAME }}
71+
TWINE_PASSWORD: ${{ secrets.PYPI_PASSWORD }}
72+
run: |
73+
../../.github/workflows/scripts/release_windows.bat
Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
#!/bin/bash
2+
3+
set -e
4+
5+
yum makecache -y
6+
yum install centos-release-scl -y
7+
yum-config-manager --enable rhel-server-rhscl-7-rpms
8+
yum install llvm-toolset-7.0 python3 python3-devel -y
9+
10+
# Python
11+
python3 -m pip install --upgrade pip
12+
python3 -m pip install setuptools wheel twine auditwheel
13+
14+
# Publish
15+
python3 -m pip wheel . -w dist/ --no-deps
16+
twine upload --verbose --skip-existing dist/*
Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
#!/bin/sh
2+
3+
export MACOSX_DEPLOYMENT_TARGET=10.14
4+
5+
python -m pip install --upgrade pip
6+
pip install setuptools wheel twine auditwheel
7+
8+
python3 setup.py build bdist_wheel --plat-name macosx_10_14_x86_64 --dist-dir wheel
9+
twine upload --skip-existing wheel/*
Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
echo on
2+
3+
python -m pip install --upgrade pip
4+
pip install setuptools wheel twine auditwheel
5+
6+
pip wheel . -w wheel/ --no-deps
7+
twine upload --skip-existing wheel/*

.github/workflows/test_decaf.yml

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,13 @@
1-
name: DECAF
1+
name: Tests
22

33
on:
44
push:
55
branches: [main, release]
66
pull_request:
77
types: [opened, synchronize, reopened]
8+
schedule:
9+
- cron: '6 5 * * 3'
10+
811

912
jobs:
1013
Linter:
@@ -34,7 +37,7 @@ jobs:
3437
runs-on: ${{ matrix.os }}
3538
strategy:
3639
matrix:
37-
python-version: [3.6, 3.7, 3.8, 3.9]
40+
python-version: ["3.7", "3.8", "3.9", "3.10"]
3841
os: [macos-latest, ubuntu-latest, windows-latest]
3942
steps:
4043
- uses: actions/checkout@v2

.pre-commit-config.yaml

Lines changed: 4 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
exclude: 'setup.py|^docs'
22

33
repos:
4-
- repo: git://github.com/pre-commit/pre-commit-hooks
4+
- repo: https://github.com/pre-commit/pre-commit-hooks
55
rev: v3.4.0
66
hooks:
77
- id: trailing-whitespace
@@ -24,34 +24,18 @@ repos:
2424
- id: isort
2525

2626
- repo: https://github.com/psf/black
27-
rev: 21.4b0
27+
rev: 22.3.0
2828
hooks:
2929
- id: black
3030
language_version: python3
31-
- repo: https://gitlab.com/pycqa/flake8
31+
- repo: https://github.com/pycqa/flake8
3232
rev: 3.9.1
3333
hooks:
3434
- id: flake8
3535
args: [
36-
"--max-line-length=340",
36+
"--max-line-length=440",
3737
"--extend-ignore=E203,W503"
3838
]
39-
- repo: https://github.com/pre-commit/mirrors-mypy
40-
rev: v0.812
41-
hooks:
42-
- id: mypy
43-
args: [
44-
"--ignore-missing-imports",
45-
"--scripts-are-modules",
46-
"--disallow-incomplete-defs",
47-
"--no-implicit-optional",
48-
"--warn-unused-ignores",
49-
"--warn-redundant-casts",
50-
"--strict-equality",
51-
"--warn-unreachable",
52-
"--disallow-untyped-defs",
53-
"--disallow-untyped-calls",
54-
]
5539
- repo: local
5640
hooks:
5741
- id: flynt

decaf/DECAF.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -10,14 +10,16 @@
1010

1111
import decaf.logger as log
1212

13+
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
14+
1315

1416
class TraceExpm(torch.autograd.Function):
1517
@staticmethod
1618
def forward(ctx: Any, input: torch.Tensor) -> torch.Tensor:
1719
# detach so we can cast to NumPy
1820
E = slin.expm(input.detach().numpy())
1921
f = np.trace(E)
20-
E = torch.from_numpy(E)
22+
E = torch.from_numpy(E).to(DEVICE)
2123
ctx.save_for_backward(E)
2224
return torch.as_tensor(f, dtype=input.dtype)
2325

@@ -183,8 +185,8 @@ def __init__(
183185
h_dim=h_dim,
184186
use_mask=use_mask,
185187
dag_seed=dag_seed,
186-
)
187-
self.discriminator = Discriminator(x_dim=self.x_dim, h_dim=h_dim)
188+
).to(DEVICE)
189+
self.discriminator = Discriminator(x_dim=self.x_dim, h_dim=h_dim).to(DEVICE)
188190

189191
self.dag_seed = dag_seed
190192

@@ -216,7 +218,7 @@ def gradient_dag_loss(self, x: torch.Tensor, z: torch.Tensor) -> torch.Tensor:
216218
)[0]
217219
W[i] = torch.sum(torch.abs(gradients), axis=0)
218220

219-
h = trace_expm(W ** 2) - self.hparams.x_dim
221+
h = trace_expm(W**2) - self.hparams.x_dim
220222

221223
return 0.5 * self.hparams.rho * h * h + self.hparams.alpha * h
222224

@@ -277,10 +279,10 @@ def get_W(self) -> torch.Tensor:
277279

278280
def dag_loss(self) -> torch.Tensor:
279281
W = self.get_W()
280-
h = trace_expm(W ** 2) - self.x_dim
282+
h = trace_expm(W**2) - self.x_dim
281283
l1_loss = torch.norm(W, 1)
282284
return (
283-
0.5 * self.hparams.rho * h ** 2
285+
0.5 * self.hparams.rho * h**2
284286
+ self.hparams.alpha * h
285287
+ self.hparams.l1_W * l1_loss
286288
)

decaf/version.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
__version__ = "0.1.1"
1+
__version__ = "0.1.2"

requirements.txt

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,10 @@
11
loguru
2-
networkx == 2.*
3-
numpy == 1.19.*
2+
networkx>=2.0
3+
numpy>=1.19
44
pandas
5-
pytorch-lightning == 1.4.*
5+
pytorch-lightning>=1.4
66
scipy
77
sklearn
8-
torch == 1.9.*
9-
torchtext == 0.10.*
8+
torch>=1.9
9+
torchtext>=0.10.*
1010
xgboost

tests/test_decaf.py

Lines changed: 2 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -2,13 +2,9 @@
22

33
import networkx as nx
44
import numpy as np
5-
import pandas as pd
6-
import pytest
75
import pytorch_lightning as pl
86
import torch
9-
from sklearn.metrics import precision_score, recall_score, roc_auc_score
10-
from utils import gen_data_nonlinear, load_adult
11-
from xgboost import XGBClassifier
7+
from utils import gen_data_nonlinear
128

139
from decaf import DECAF, DataModule
1410

@@ -72,7 +68,7 @@ def test_sanity_generate() -> None:
7268
dummy_dm.dims[0],
7369
dag_seed=seed,
7470
)
75-
trainer = pl.Trainer(max_epochs=2, logger=False)
71+
trainer = pl.Trainer(max_epochs=100, logger=True)
7672

7773
trainer.fit(model, dummy_dm)
7874

@@ -84,53 +80,3 @@ def test_sanity_generate() -> None:
8480
.numpy()
8581
)
8682
assert synth_data.shape[0] == 10
87-
88-
89-
@pytest.mark.parametrize("X,y", [load_adult()])
90-
@pytest.mark.slow
91-
def test_run_experiments(X: pd.DataFrame, y: pd.DataFrame) -> None:
92-
baseline_clf = XGBClassifier().fit(X, y)
93-
y_pred = baseline_clf.predict(X)
94-
95-
print(
96-
"baseline scores",
97-
precision_score(y, y_pred),
98-
recall_score(y, y_pred),
99-
roc_auc_score(y, y_pred),
100-
)
101-
102-
dm = DataModule(X)
103-
104-
model = DECAF(
105-
dm.dims[0],
106-
use_mask=True,
107-
grad_dag_loss=False,
108-
lambda_privacy=0,
109-
lambda_gp=10,
110-
weight_decay=1e-2,
111-
l1_g=0,
112-
p_gen=-1,
113-
batch_size=100,
114-
)
115-
trainer = pl.Trainer(max_epochs=10, logger=False)
116-
trainer.fit(model, dm)
117-
118-
X_synth = (
119-
model.gen_synthetic(
120-
dm.dataset.x,
121-
gen_order=model.get_gen_order(),
122-
)
123-
.detach()
124-
.numpy()
125-
)
126-
y_synth = baseline_clf.predict(X_synth)
127-
128-
synth_clf = XGBClassifier().fit(X_synth, y_synth)
129-
y_pred = synth_clf.predict(X_synth)
130-
131-
print(
132-
"synth scores",
133-
precision_score(y_synth, y_pred),
134-
recall_score(y_synth, y_pred),
135-
roc_auc_score(y_synth, y_pred),
136-
)

0 commit comments

Comments
 (0)