2626)
2727
2828if TYPE_CHECKING :
29- from fairlearn .reductions import ExponentiatedGradient
29+ from fairlearn .reductions import ExponentiatedGradient # pyright: ignore
3030
3131 from ethicml .models .inprocess .agarwal_reductions import AgarwalArgs
3232 from ethicml .models .inprocess .in_subprocess import InAlgoArgs
3636def fit (train : DataTuple , args : AgarwalArgs , seed : int = 888 ) -> ExponentiatedGradient :
3737 """Fit a model."""
3838 try :
39- from fairlearn .reductions import (
39+ from fairlearn .reductions import ( # pyright: ignore
4040 DemographicParity ,
4141 EqualizedOdds ,
4242 ExponentiatedGradient ,
@@ -81,7 +81,7 @@ def fit(train: DataTuple, args: AgarwalArgs, seed: int = 888) -> ExponentiatedGr
8181 exponentiated_gradient .fit (data_x , data_y , sensitive_features = data_a )
8282
8383 min_class_label = train .y .min ()
84- exponentiated_gradient .min_class_label = min_class_label
84+ exponentiated_gradient .min_class_label = min_class_label # pyright: ignore
8585
8686 return exponentiated_gradient
8787
@@ -92,7 +92,7 @@ def predict(exponentiated_gradient: ExponentiatedGradient, test: TestTuple) -> p
9292 preds = pd .DataFrame (randomized_predictions , columns = ["preds" ])
9393
9494 if (min_val := preds ["preds" ].min ()) != preds ["preds" ].max ():
95- preds = preds .replace (min_val , exponentiated_gradient .min_class_label )
95+ preds = preds .replace (min_val , exponentiated_gradient .min_class_label ) # pyright: ignore
9696 return preds
9797
9898
@@ -120,7 +120,7 @@ def main() -> None:
120120 in_algo_args : InAlgoArgs = json .loads (sys .argv [1 ])
121121 flags : AgarwalArgs = json .loads (sys .argv [2 ])
122122 try :
123- import cloudpickle
123+ import cloudpickle # pyright: ignore
124124
125125 # Need to install cloudpickle for now. See https://github.com/fairlearn/fairlearn/issues/569
126126 except ImportError as e :
0 commit comments