Skip to content

Commit 99cbcc3

Browse files
committed
Merge branch 'main' into release
2 parents da78f9b + dc7d3a6 commit 99cbcc3

Some content is hidden

Large commits have some content hidden by default. Use the search box below to find content that may be hidden.

56 files changed

+842
-846
lines changed

.github/workflows/continuous_integration.yml

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ on:
1919
- LICENSE
2020
- make_release.sh
2121
- CITATION.cff
22+
merge_group:
2223

2324

2425
jobs:
@@ -39,10 +40,10 @@ jobs:
3940
pip install ruff
4041
- name: Lint with ruff
4142
run: |
42-
ruff ethicml
43+
ruff check --format=github ethicml
4344
- name: Lint with ruff
4445
run: |
45-
ruff tests
46+
ruff check --format=github tests
4647
4748
format_with_black:
4849

@@ -149,8 +150,10 @@ jobs:
149150
#----------------------------------------------
150151
- uses: actions/checkout@v3
151152
- name: Install poetry
153+
if: ${{ github.event_name == 'merge_group' }}
152154
run: pipx install poetry
153155
- uses: actions/setup-python@v4
156+
if: ${{ github.event_name == 'merge_group' }}
154157
with:
155158
python-version: '3.8'
156159
cache: 'poetry'
@@ -159,6 +162,7 @@ jobs:
159162
# --------- install dependencies --------
160163
#----------------------------------------------
161164
- name: Install dependencies
165+
if: ${{ github.event_name == 'merge_group' }}
162166
run: |
163167
# keep the following in sync with `test_full_dependencies`!
164168
poetry env use 3.8
@@ -168,6 +172,7 @@ jobs:
168172
# ----- Run MyPy -----
169173
#----------------------------------------------
170174
- name: Type check with mypy
175+
if: ${{ github.event_name == 'merge_group' }}
171176
run: |
172177
poetry run python run_mypy.py
173178
poetry run python run_mypy_tests.py
@@ -176,5 +181,6 @@ jobs:
176181
# ----- Run Tests -----
177182
#----------------------------------------------
178183
- name: Test with pytest
184+
if: ${{ github.event_name == 'merge_group' }}
179185
run: |
180186
poetry run python -m pytest -vv -n 2 --dist loadgroup --cov=ethicml --cov-fail-under=80 tests/

ethicml/common.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,9 @@
11
"""Common variables / constants that make things run smoother."""
22
from importlib import util
3-
import os
43
from pathlib import Path
54

6-
__all__ = ["TORCH_AVAILABLE", "ROOT_DIR", "ROOT_PATH"]
5+
__all__ = ["TORCH_AVAILABLE", "ROOT_PATH"]
76

87
TORCH_AVAILABLE = util.find_spec("torch") is not None
98

10-
ROOT_DIR: str = os.path.abspath(os.path.join(os.path.abspath(__file__), os.pardir))
119
ROOT_PATH: Path = Path(__file__).parent.resolve()

ethicml/data/csvs/make_adult_from_raw.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ def run_generate_adult() -> None:
3636
all_data = pd.concat([train, test], axis=0)
3737

3838
for col in all_data.columns:
39-
if all_data[col].dtype == np.object: # type: ignore[attr-defined]
39+
if all_data[col].dtype == object:
4040
all_data[col] = all_data[col].str.strip()
4141

4242
# Replace full stop in the label of the test set

ethicml/data/csvs/make_crime_from_raw.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -160,7 +160,7 @@ def run_generate_crime() -> None:
160160
data.columns = pd.Index(columns)
161161

162162
for col in data.columns:
163-
if data[col].dtype == np.object: # type: ignore[attr-defined]
163+
if data[col].dtype == object:
164164
data[col] = data[col].str.strip()
165165

166166
# Drop NaNs

ethicml/data/dataset.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
from typing_extensions import override
1010

1111
import pandas as pd
12-
from ranzen import StrEnum
12+
from ranzen.misc import StrEnum
1313

1414
from ethicml.common import ROOT_PATH
1515
from ethicml.utility import DataTuple, undo_one_hot
@@ -417,6 +417,7 @@ class LegacyDataset(CSVDataset):
417417

418418
def __init__(
419419
self,
420+
*,
420421
name: str,
421422
filename_or_path: str | Path,
422423
features: Sequence[str],

ethicml/data/tabular_data/acs.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -89,7 +89,7 @@
8989

9090
@contextlib.contextmanager
9191
def _download_dir(root: Path) -> Generator[None, None, None]:
92-
curdir = os.getcwd()
92+
curdir = Path.cwd()
9393
os.chdir(root.expanduser().resolve())
9494
try:
9595
yield

ethicml/implementations/adv_debiasing_modules/model.py

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -192,6 +192,7 @@ class AdvDebiasingClassLearner:
192192

193193
def __init__(
194194
self,
195+
*,
195196
lr: float,
196197
n_clf_epochs: int,
197198
n_adv_epochs: int,
@@ -231,9 +232,9 @@ def __init__(
231232

232233
self.n_epoch_combined = n_epoch_combined
233234

234-
def fit(self, train: DataTuple, seed: int) -> Self: # type: ignore[valid-type]
235+
def fit(self, train: DataTuple, seed: int) -> Self:
235236
"""Fit."""
236-
train_data, train_loader = make_dataset_and_loader(
237+
_, train_loader = make_dataset_and_loader(
237238
train, batch_size=self.batch_size, shuffle=True, seed=seed, drop_last=True
238239
)
239240

@@ -268,9 +269,9 @@ def fit(self, train: DataTuple, seed: int) -> Self: # type: ignore[valid-type]
268269
@torch.no_grad()
269270
def predict(self, x: pd.DataFrame) -> np.ndarray:
270271
"""Predict."""
271-
x = torch.from_numpy(x.to_numpy()).float()
272+
x_ = torch.from_numpy(x.to_numpy()).float()
272273
self.clf.eval()
273-
yhat = self.clf(x)
274+
yhat = self.clf(x_)
274275
sm = nn.Softmax(dim=1)
275276
yhat = sm(yhat)
276277
yhat = yhat.detach().numpy()
@@ -283,6 +284,7 @@ class AdvDebiasingRegLearner:
283284

284285
def __init__(
285286
self,
287+
*,
286288
lr: float,
287289
n_clf_epochs: int,
288290
n_adv_epochs: int,
@@ -322,7 +324,7 @@ def __init__(
322324

323325
self.n_epoch_combined = n_epoch_combined
324326

325-
def fit(self, train: DataTuple, seed: int) -> Self: # type: ignore[valid-type]
327+
def fit(self, train: DataTuple, seed: int) -> Self:
326328
"""Fit."""
327329
# The features are X[:,1:]
328330

@@ -361,9 +363,9 @@ def fit(self, train: DataTuple, seed: int) -> Self: # type: ignore[valid-type]
361363
@torch.no_grad()
362364
def predict(self, x: pd.DataFrame) -> torch.Tensor:
363365
"""Predict."""
364-
x = torch.from_numpy(x.to_numpy()).float()
366+
x_ = torch.from_numpy(x.to_numpy()).float()
365367
self.clf.eval()
366-
yhat = self.clf(x).squeeze().detach().numpy()
368+
yhat = self.clf(x_).squeeze().detach().numpy()
367369
if self.out_shape == 1:
368370
out = yhat
369371
else:

ethicml/implementations/agarwal.py

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from pathlib import Path
77
import random
88
import sys
9-
from typing import TYPE_CHECKING, Generator
9+
from typing import TYPE_CHECKING, Generator, Union
1010

1111
from joblib import dump, load
1212
import numpy as np
@@ -26,16 +26,17 @@
2626
)
2727

2828
if TYPE_CHECKING:
29-
from fairlearn.reductions import ExponentiatedGradient
29+
from fairlearn.reductions import ExponentiatedGradient # pyright: ignore
3030

3131
from ethicml.models.inprocess.agarwal_reductions import AgarwalArgs
3232
from ethicml.models.inprocess.in_subprocess import InAlgoArgs
33+
from ethicml.models.inprocess.shared import LinearModel
3334

3435

3536
def fit(train: DataTuple, args: AgarwalArgs, seed: int = 888) -> ExponentiatedGradient:
3637
"""Fit a model."""
3738
try:
38-
from fairlearn.reductions import (
39+
from fairlearn.reductions import ( # pyright: ignore
3940
DemographicParity,
4041
EqualizedOdds,
4142
ExponentiatedGradient,
@@ -50,13 +51,14 @@ def fit(train: DataTuple, args: AgarwalArgs, seed: int = 888) -> ExponentiatedGr
5051
fairness_class: UtilityParity
5152
fairness_type = FairnessType(args["fairness"])
5253
classifier_type = ClassifierType(args["classifier"])
53-
kernel_type = None if args["kernel"] == "" else KernelType[args["kernel"]]
54+
kernel_type = None if not args["kernel"] else KernelType[args["kernel"]]
5455

5556
if fairness_type is FairnessType.dp:
5657
fairness_class = DemographicParity(difference_bound=args["eps"])
5758
else:
5859
fairness_class = EqualizedOdds(difference_bound=args["eps"])
5960

61+
model: Union[LinearModel, GradientBoostingClassifier]
6062
if classifier_type is ClassifierType.svm:
6163
assert kernel_type is not None
6264
model = select_svm(C=args["C"], kernel=kernel_type, seed=seed)
@@ -79,7 +81,7 @@ def fit(train: DataTuple, args: AgarwalArgs, seed: int = 888) -> ExponentiatedGr
7981
exponentiated_gradient.fit(data_x, data_y, sensitive_features=data_a)
8082

8183
min_class_label = train.y.min()
82-
exponentiated_gradient.min_class_label = min_class_label
84+
exponentiated_gradient.min_class_label = min_class_label # pyright: ignore
8385

8486
return exponentiated_gradient
8587

@@ -90,7 +92,7 @@ def predict(exponentiated_gradient: ExponentiatedGradient, test: TestTuple) -> p
9092
preds = pd.DataFrame(randomized_predictions, columns=["preds"])
9193

9294
if (min_val := preds["preds"].min()) != preds["preds"].max():
93-
preds = preds.replace(min_val, exponentiated_gradient.min_class_label)
95+
preds = preds.replace(min_val, exponentiated_gradient.min_class_label) # pyright: ignore
9496
return preds
9597

9698

@@ -105,7 +107,7 @@ def train_and_predict(
105107
@contextlib.contextmanager
106108
def working_dir(root: Path) -> Generator[None, None, None]:
107109
"""Change the working directory to the given path."""
108-
curdir = os.getcwd()
110+
curdir = Path.cwd()
109111
os.chdir(root.expanduser().resolve().parent)
110112
try:
111113
yield
@@ -118,7 +120,7 @@ def main() -> None:
118120
in_algo_args: InAlgoArgs = json.loads(sys.argv[1])
119121
flags: AgarwalArgs = json.loads(sys.argv[2])
120122
try:
121-
import cloudpickle
123+
import cloudpickle # pyright: ignore
122124

123125
# Need to install cloudpickle for now. See https://github.com/fairlearn/fairlearn/issues/569
124126
except ImportError as e:

ethicml/implementations/beutel.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -80,14 +80,13 @@ def fit(train: DataTuple, flags: BeutelArgs, seed: int = 888) -> tuple[DataTuple
8080
set_seed(seed)
8181
fairness = FairnessType[flags["fairness"]]
8282

83-
post_process = False
83+
processor: LabelBinarizer | None = None
8484
if flags["y_loss"] == "BCELoss()":
8585
try:
8686
assert_binary_labels(train)
8787
except AssertionError:
8888
processor = LabelBinarizer()
8989
train = processor.adjust(train)
90-
post_process = True
9190

9291
# By default we use 10% of the training data for validation
9392
train_, validation = train_test_split(train, train_percentage=1 - flags["validation_pcnt"])
@@ -136,6 +135,8 @@ def fit(train: DataTuple, flags: BeutelArgs, seed: int = 888) -> tuple[DataTuple
136135
raise NotImplementedError("Not implemented Eq. Odds yet")
137136
elif fairness is FairnessType.dp:
138137
mask = torch.ones(s_pred.shape, dtype=torch.uint8)
138+
else:
139+
raise NotImplementedError(f"Unknown value: {fairness}")
139140
loss += s_loss_fn(
140141
s_pred, torch.masked_select(sens_label, mask).view(-1, int(train_data.sdim))
141142
)
@@ -169,8 +170,8 @@ def fit(train: DataTuple, flags: BeutelArgs, seed: int = 888) -> tuple[DataTuple
169170
enc.load_state_dict(best_enc)
170171

171172
transformed_train = encode_dataset(enc, all_train_data_loader, train)
172-
if post_process:
173-
transformed_train = processor.post(encode_dataset(enc, all_train_data_loader, train))
173+
if processor is not None:
174+
transformed_train = processor.post(transformed_train)
174175
return transformed_train, enc
175176

176177

@@ -207,6 +208,8 @@ def get_mask(flags: BeutelArgs, s_pred: Tensor, class_label: Tensor) -> Tensor:
207208
raise NotImplementedError("Not implemented Eq. Odds yet")
208209
elif fairness is FairnessType.dp:
209210
mask = torch.ones(s_pred.shape, dtype=torch.uint8)
211+
else:
212+
raise NotImplementedError("Shouldn't be hit.")
210213
return mask
211214

212215

0 commit comments

Comments (0)