66from pathlib import Path
77import random
88import sys
9- from typing import TYPE_CHECKING , Generator
9+ from typing import TYPE_CHECKING , Generator , Union
1010
1111from joblib import dump , load
1212import numpy as np
2626)
2727
2828if TYPE_CHECKING :
29- from fairlearn .reductions import ExponentiatedGradient
29+ from fairlearn .reductions import ExponentiatedGradient # pyright: ignore
3030
3131 from ethicml .models .inprocess .agarwal_reductions import AgarwalArgs
3232 from ethicml .models .inprocess .in_subprocess import InAlgoArgs
33+ from ethicml .models .inprocess .shared import LinearModel
3334
3435
3536def fit (train : DataTuple , args : AgarwalArgs , seed : int = 888 ) -> ExponentiatedGradient :
3637 """Fit a model."""
3738 try :
38- from fairlearn .reductions import (
39+ from fairlearn .reductions import ( # pyright: ignore
3940 DemographicParity ,
4041 EqualizedOdds ,
4142 ExponentiatedGradient ,
@@ -50,13 +51,14 @@ def fit(train: DataTuple, args: AgarwalArgs, seed: int = 888) -> ExponentiatedGr
5051 fairness_class : UtilityParity
5152 fairness_type = FairnessType (args ["fairness" ])
5253 classifier_type = ClassifierType (args ["classifier" ])
53- kernel_type = None if args ["kernel" ] == "" else KernelType [args ["kernel" ]]
54+ kernel_type = None if not args ["kernel" ] else KernelType [args ["kernel" ]]
5455
5556 if fairness_type is FairnessType .dp :
5657 fairness_class = DemographicParity (difference_bound = args ["eps" ])
5758 else :
5859 fairness_class = EqualizedOdds (difference_bound = args ["eps" ])
5960
61+ model : Union [LinearModel , GradientBoostingClassifier ]
6062 if classifier_type is ClassifierType .svm :
6163 assert kernel_type is not None
6264 model = select_svm (C = args ["C" ], kernel = kernel_type , seed = seed )
@@ -79,7 +81,7 @@ def fit(train: DataTuple, args: AgarwalArgs, seed: int = 888) -> ExponentiatedGr
7981 exponentiated_gradient .fit (data_x , data_y , sensitive_features = data_a )
8082
8183 min_class_label = train .y .min ()
82- exponentiated_gradient .min_class_label = min_class_label
84+ exponentiated_gradient .min_class_label = min_class_label # pyright: ignore
8385
8486 return exponentiated_gradient
8587
@@ -90,7 +92,7 @@ def predict(exponentiated_gradient: ExponentiatedGradient, test: TestTuple) -> p
9092 preds = pd .DataFrame (randomized_predictions , columns = ["preds" ])
9193
9294 if (min_val := preds ["preds" ].min ()) != preds ["preds" ].max ():
93- preds = preds .replace (min_val , exponentiated_gradient .min_class_label )
95+ preds = preds .replace (min_val , exponentiated_gradient .min_class_label ) # pyright: ignore
9496 return preds
9597
9698
@@ -105,7 +107,7 @@ def train_and_predict(
105107@contextlib .contextmanager
106108def working_dir (root : Path ) -> Generator [None , None , None ]:
107109 """Change the working directory to the given path."""
108- curdir = os . getcwd ()
110+ curdir = Path . cwd ()
109111 os .chdir (root .expanduser ().resolve ().parent )
110112 try :
111113 yield
@@ -118,7 +120,7 @@ def main() -> None:
118120 in_algo_args : InAlgoArgs = json .loads (sys .argv [1 ])
119121 flags : AgarwalArgs = json .loads (sys .argv [2 ])
120122 try :
121- import cloudpickle
123+ import cloudpickle # pyright: ignore
122124
123125 # Need to install cloudpickle for now. See https://github.com/fairlearn/fairlearn/issues/569
124126 except ImportError as e :
0 commit comments