Skip to content

Commit 30c2ddc

Browse files
committed
Update type stubs
1 parent 7e90e14 commit 30c2ddc

File tree

9 files changed

+22
-25
lines changed

9 files changed

+22
-25
lines changed

ethicml/implementations/agarwal.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@
2626
)
2727

2828
if TYPE_CHECKING:
29-
from fairlearn.reductions import ExponentiatedGradient
29+
from fairlearn.reductions import ExponentiatedGradient # pyright: ignore
3030

3131
from ethicml.models.inprocess.agarwal_reductions import AgarwalArgs
3232
from ethicml.models.inprocess.in_subprocess import InAlgoArgs
@@ -36,7 +36,7 @@
3636
def fit(train: DataTuple, args: AgarwalArgs, seed: int = 888) -> ExponentiatedGradient:
3737
"""Fit a model."""
3838
try:
39-
from fairlearn.reductions import (
39+
from fairlearn.reductions import ( # pyright: ignore
4040
DemographicParity,
4141
EqualizedOdds,
4242
ExponentiatedGradient,
@@ -81,7 +81,7 @@ def fit(train: DataTuple, args: AgarwalArgs, seed: int = 888) -> ExponentiatedGr
8181
exponentiated_gradient.fit(data_x, data_y, sensitive_features=data_a)
8282

8383
min_class_label = train.y.min()
84-
exponentiated_gradient.min_class_label = min_class_label
84+
exponentiated_gradient.min_class_label = min_class_label # pyright: ignore
8585

8686
return exponentiated_gradient
8787

@@ -92,7 +92,7 @@ def predict(exponentiated_gradient: ExponentiatedGradient, test: TestTuple) -> p
9292
preds = pd.DataFrame(randomized_predictions, columns=["preds"])
9393

9494
if (min_val := preds["preds"].min()) != preds["preds"].max():
95-
preds = preds.replace(min_val, exponentiated_gradient.min_class_label)
95+
preds = preds.replace(min_val, exponentiated_gradient.min_class_label) # pyright: ignore
9696
return preds
9797

9898

@@ -120,7 +120,7 @@ def main() -> None:
120120
in_algo_args: InAlgoArgs = json.loads(sys.argv[1])
121121
flags: AgarwalArgs = json.loads(sys.argv[2])
122122
try:
123-
import cloudpickle
123+
import cloudpickle # pyright: ignore
124124

125125
# Need to install cloudpickle for now. See https://github.com/fairlearn/fairlearn/issues/569
126126
except ImportError as e:

ethicml/metrics/accuracy.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -30,23 +30,23 @@ def score(self, prediction: Prediction, actual: EvalTuple) -> float:
3030
class Accuracy(SklearnMetric):
3131
"""Classification accuracy."""
3232

33-
sklearn_metric = (accuracy_score,)
33+
sklearn_metric = (accuracy_score,) # type: ignore[assignment]
3434
_name: ClassVar[str] = "Accuracy"
3535

3636

3737
@dataclass
3838
class F1(SklearnMetric):
3939
"""F1 score: harmonic mean of precision and recall."""
4040

41-
sklearn_metric = (f1_score,)
41+
sklearn_metric = (f1_score,) # type: ignore[assignment]
4242
_name: ClassVar[str] = "F1"
4343

4444

4545
@dataclass
4646
class RobustAccuracy(SklearnMetric):
4747
"""Minimum Classification accuracy across S-groups."""
4848

49-
sklearn_metric = (accuracy_score,)
49+
sklearn_metric = (accuracy_score,) # type: ignore[assignment]
5050
apply_per_sensitive: ClassVar[bool] = False
5151
_name: ClassVar[str] = "Robust Accuracy"
5252

ethicml/metrics/dependence_measures.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ class NMI(_DependenceMeasure):
4949
@override
5050
def score(self, prediction: Prediction, actual: EvalTuple) -> float:
5151
base_values = actual.y if self.base is DependencyTarget.y else actual.s
52-
return normalized_mutual_info_score(
52+
return normalized_mutual_info_score( # type: ignore[return-value]
5353
base_values.to_numpy().ravel(),
5454
prediction.hard.to_numpy().ravel(),
5555
average_method="arithmetic",

ethicml/models/inprocess/kamiran.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -130,7 +130,7 @@ def _predict(self, model: LinearModel, test: TestTuple) -> Prediction:
130130
hard=pd.Series(model.predict(test.x.to_numpy())), info=self.hyperparameters
131131
)
132132
return SoftPrediction(
133-
soft=model.predict_proba(test.x.to_numpy()), # type: ignore[arg-type]
133+
soft=model.predict_proba(test.x.to_numpy()),
134134
info=self.hyperparameters,
135135
)
136136

ethicml/models/inprocess/logistic_regression.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@ def run(self, train: DataTuple, test: TestTuple, seed: int = 888) -> SoftPredict
5555
)
5656
clf.fit(train.x, train.y.to_numpy().ravel())
5757
return SoftPrediction(
58-
soft=clf.predict_proba(test.x.to_numpy()), # type: ignore[arg-type]
58+
soft=clf.predict_proba(test.x.to_numpy()),
5959
info=self.hyperparameters,
6060
)
6161

@@ -88,9 +88,9 @@ def fit(self, train: DataTuple, seed: int = 888) -> LRCV:
8888
@override
8989
def predict(self, test: TestTuple) -> Prediction:
9090
params = self.hyperparameters
91-
params["C"] = self.clf.C_[0] # type: ignore[attr-defined]
91+
params["C"] = self.clf.C_[0]
9292
return SoftPrediction(
93-
soft=self.clf.predict_proba(test.x.to_numpy()), # type: ignore[arg-type]
93+
soft=self.clf.predict_proba(test.x.to_numpy()),
9494
info=params,
9595
)
9696

@@ -103,8 +103,8 @@ def run(self, train: DataTuple, test: TestTuple, seed: int = 888) -> Prediction:
103103
)
104104
clf.fit(train.x, train.y.to_numpy().ravel())
105105
params = self.hyperparameters
106-
params["C"] = clf.C_[0] # type: ignore[attr-defined]
106+
params["C"] = clf.C_[0]
107107
return SoftPrediction(
108-
soft=clf.predict_proba(test.x.to_numpy()), # type: ignore[arg-type]
108+
soft=clf.predict_proba(test.x.to_numpy()),
109109
info=params,
110110
)

ethicml/plot/plotting.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -32,9 +32,7 @@ def _maybe_tsne(data: DataTuple) -> tuple[pd.DataFrame, str, str]:
3232
if len(columns) > 2:
3333
from sklearn.manifold import TSNE
3434

35-
tsne_embeddings = TSNE(n_components=2, random_state=0).fit_transform(
36-
data.x, # type: ignore[arg-type]
37-
)
35+
tsne_embeddings = TSNE(n_components=2, random_state=0).fit_transform(data.x)
3836
amalgamated = pd.concat(
3937
[pd.DataFrame(tsne_embeddings, columns=["tsne1", "tsne2"]), data.s, data.y],
4038
axis="columns",

ethicml/preprocessing/splits.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -116,8 +116,8 @@ def train_test_split(
116116
pd.testing.assert_index_equal(train.x.columns, x_columns)
117117
pd.testing.assert_index_equal(test.x.columns, x_columns)
118118

119-
assert isinstance(train.s, pd.Series)
120-
assert isinstance(test.s, pd.Series)
119+
assert isinstance(train.s, pd.Series) # type: ignore[unreachable]
120+
assert isinstance(test.s, pd.Series) # type: ignore[unreachable]
121121
assert train.s.name == data.s_column
122122
assert test.s.name == data.s_column
123123

poetry.lock

Lines changed: 3 additions & 3 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

pyproject.toml

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@ mypy = ">=0.990"
5454
pre-commit = "^2.20.0"
5555
pytest = ">=6.2.2,<8.0.0"
5656
pytest-cov = ">=2.6,<4.0"
57-
python-type-stubs = {git = "https://github.com/wearepal/python-type-stubs.git", rev = "417b03b"}
57+
python-type-stubs = {git = "https://github.com/wearepal/python-type-stubs.git", rev = "2ea8053"}
5858
pandas-stubs = ">=1.4.2.220626"
5959
omegaconf = ">=2.2.1"
6060
pytest-xdist = "^2.5.0"
@@ -173,7 +173,6 @@ module = [
173173
"pytest.*",
174174
"setuptools.*",
175175
"scipy.spatial.distance",
176-
"sklearn.*",
177176
]
178177
ignore_missing_imports = true
179178

0 commit comments

Comments
 (0)