@@ -294,9 +294,7 @@ class Model:
294294 predictive model to interface with the TrustyAI Java library.
295295 """
296296
297- def __init__ (
298- self , predict_fun , dataframe_input = False , output_names = None , disable_arrow = False
299- ):
297+ def __init__ (self , predict_fun , ** kwargs ):
300298 """
301299 Wrap the model as a TrustyAI :obj:`PredictionProvider` Java class.
302300
@@ -306,20 +304,26 @@ def __init__(
306304 A function that takes in a Numpy array or Pandas DataFrame as input and outputs a
307305 Pandas DataFrame or Numpy array. In general, the ``model.predict`` functions of
308306 sklearn-style models meet this requirement.
309- dataframe_input: bool
310- Whether `predict_fun` expects a :class:`pandas.DataFrame` as input.
311- output_names : List[String]:
312- If the model outputs a numpy array, you can specify the names of the model outputs
313- here.
314- disable_arrow: bool
315- If true, Apache Arrow will not be used to accelerate data transfer between Java
316- and Python. If false, Arrow will be automatically used in situations where it is
317- advantageous to do so.
307+
308+ Keyword Arguments:
309+ * dataframe_input: bool
310+ (default= ``False``) Whether `predict_fun` expects a :class:`pandas.DataFrame`
311+ as input.
312+ * feature_names : List[String]:
313+ (default= ``None`) If the model receives a non-pandas input, you can specify the
314+ names of the model input features here, with the ith element of the list
315+ corresponding to the name of the ith feature.
316+ * output_names : List[String]:
317+ (default= ``None`) If the model outputs a non-pandas object, you can specify the
318+ names of the model outputs here, with the ith element of the list corresponding to
319+ the name of the ith output.
320+ * disable_arrow: bool
321+ (default= ``False`) If true, Apache Arrow will not be used to accelerate data
322+ transfer between Java and Python. If false, Arrow will be automatically used in
323+ situations where it is advantageous to do so.
318324 """
319- self .disable_arrow = disable_arrow
320325 self .predict_fun = predict_fun
321- self .output_names = output_names
322- self .dataframe_input = dataframe_input
326+ self .kwargs = kwargs
323327
324328 self .prediction_provider_arrow = None
325329 self .prediction_provider_normal = None
@@ -328,6 +332,26 @@ def __init__(
328332 # set model to use non-arrow by default, as this requires no dataset information
329333 self ._set_nonarrow ()
330334
335+ @property
336+ def dataframe_input (self ):
337+ """Get dataframe_input kwarg value"""
338+ return self .kwargs .get ("dataframe_input" )
339+
340+ @property
341+ def feature_names (self ):
342+ """Get feature_names kwarg value"""
343+ return self .kwargs .get ("feature_names" )
344+
345+ @property
346+ def output_names (self ):
347+ """Get output_names kwarg value"""
348+ return self .kwargs .get ("output_names" )
349+
350+ @property
351+ def disable_arrow (self ):
352+ """Get disable_arrow kwarg value"""
353+ return self .kwargs .get ("disable_arrow" )
354+
331355 def _set_arrow (self , paradigm_input : PredictionInput ):
332356 """
333357 Ready the model for arrow-based prediction communication.
@@ -825,7 +849,10 @@ def feature(
825849# pylint: disable=line-too-long
826850@data_conversion_docstring ("one_input" , "one_output" )
827851def simple_prediction (
828- input_features : OneInputUnionType , outputs : OneOutputUnionType
852+ input_features : OneInputUnionType ,
853+ outputs : OneOutputUnionType ,
854+ feature_names : Optional [List [str ]] = None ,
855+ output_names : Optional [List [str ]] = None ,
829856) -> SimplePrediction :
830857 """Wrap features and outputs into a SimplePrediction. Given a list of features and outputs,
831858 this function will bundle them into Prediction objects for use with the LIME and SHAP
@@ -838,10 +865,15 @@ def simple_prediction(
838865 outputs : {}
839866 The desired model outputs to be searched for in the counterfactual explanation.
840867 These can take the form of a: {}
868+ feature_names: Optional[List[str]]
869+ The names of the features, in the case where the feature object does not contain them
870+ output_names: Optional[List[str]]
871+ The names of the outputs, in the case where the outputobject does not contain them
841872 """
842873
843874 return SimplePrediction (
844- one_input_convert (input_features ), one_output_convert (outputs )
875+ one_input_convert (input_features , feature_names ),
876+ one_output_convert (outputs , output_names ),
845877 )
846878
847879
@@ -850,6 +882,8 @@ def simple_prediction(
850882def counterfactual_prediction (
851883 input_features : OneInputUnionType ,
852884 outputs : OneOutputUnionType ,
885+ feature_names : Optional [List [str ]] = None ,
886+ output_names : Optional [List [str ]] = None ,
853887 data_distribution : Optional [DataDistribution ] = None ,
854888 uuid : Optional [_uuid .UUID ] = None ,
855889 timeout : Optional [float ] = None ,
@@ -865,6 +899,10 @@ def counterfactual_prediction(
865899 outputs : {}
866900 The desired model outputs to be searched for in the counterfactual explanation.
867901 These can take the form of a: {}
902+ feature_names: Optional[List[str]]
903+ The names of the features, in the case where the feature object does not contain them
904+ output_names: Optional[List[str]]
905+ The names of the outputs, in the case where the outputobject does not contain them
868906 data_distribution : Optional[:class:`DataDistribution`]
869907 The :class:`DataDistribution` to use when sampling the inputs.
870908 uuid : Optional[:class:`_uuid.UUID`]
@@ -878,8 +916,8 @@ def counterfactual_prediction(
878916 timeout = Long (timeout )
879917
880918 return CounterfactualPrediction (
881- one_input_convert (input_features ),
882- one_output_convert (outputs ),
919+ one_input_convert (input_features , feature_names ),
920+ one_output_convert (outputs , output_names ),
883921 data_distribution ,
884922 uuid ,
885923 timeout ,
0 commit comments