@@ -295,7 +295,7 @@ class Model:
295295 """
296296
297297 def __init__ (
298- self , predict_fun , dataframe_input = False , output_names = None , arrow = False
298+ self , predict_fun , dataframe_input = False , output_names = None , disable_arrow = False
299299 ):
300300 """
301301 Wrap the model as a TrustyAI :obj:`PredictionProvider` Java class.
@@ -311,39 +311,75 @@ def __init__(
311311 output_names : List[String]:
312312 If the model outputs a numpy array, you can specify the names of the model outputs
313313 here.
314- arrow : bool
315- Whether to use Apache arrow to speed up data transfer between Java and Python.
316- In general, set this to ``true`` whenever LIME or SHAP explanations are needed,
317- and ``false`` for counterfactuals .
314+ disable_arrow : bool
315+ If true, Apache Arrow will not be used to accelerate data transfer between Java
316+ and Python. If false, Arrow will be automatically used in situations where it is
317+ advantageous to do so .
318318 """
319- self .arrow = arrow
319+ self .disable_arrow = disable_arrow
320320 self .predict_fun = predict_fun
321321 self .output_names = output_names
322+ self .dataframe_input = dataframe_input
322323
323- if arrow :
324- self .prediction_provider = None
325- if not dataframe_input :
326- self .prediction_provider_arrow = PredictionProviderArrow (
327- lambda x : self ._cast_outputs_to_dataframe (predict_fun (x .values ))
328- )
329- else :
330- self .prediction_provider_arrow = PredictionProviderArrow (
331- lambda x : self ._cast_outputs_to_dataframe (predict_fun (x ))
324+ self .prediction_provider_arrow = None
325+ self .prediction_provider_normal = None
326+ self .prediction_provider = None
327+
328+ # set model to use non-arrow by default, as this requires no dataset information
329+ self ._set_nonarrow ()
330+
331+ def _set_arrow (self , paradigm_input : PredictionInput ):
332+ """
333+ Ready the model for arrow-based prediction communication.
334+
335+ Parameters
336+ ----------
337+ paradigm_input: A single :obj:`PredictionInput` by which to establish the arrow schema.
338+ All subsequent :obj:`PredictionInput`s communicated must have this schema.
339+ """
340+ if self .disable_arrow :
341+ self ._set_nonarrow ()
342+ else :
343+ if self .prediction_provider_arrow is None :
344+ raw_ppa = self ._get_arrow_prediction_provider ()
345+ self .prediction_provider_arrow = raw_ppa .get_as_prediction_provider (
346+ paradigm_input
332347 )
348+ self .prediction_provider = self .prediction_provider_arrow
349+
350+ def _set_nonarrow (self ):
351+ """
352+ Ready the model for non-arrow-prediction communication.
353+ """
354+ if self .prediction_provider_normal is None :
355+ self .prediction_provider_normal = self ._get_nonarrow_prediction_provider ()
356+ self .prediction_provider = self .prediction_provider_normal
357+
358+ def _get_arrow_prediction_provider (self ):
359+ if not self .dataframe_input :
360+ ppa = PredictionProviderArrow (
361+ lambda x : self ._cast_outputs_to_dataframe (self .predict_fun (x .values ))
362+ )
333363 else :
334- self .prediction_provider_arrow = None
335- if dataframe_input :
336- self .prediction_provider = PredictionProvider (
337- lambda x : self ._cast_outputs (
338- predict_fun (prediction_object_to_pandas (x ))
339- )
364+ ppa = PredictionProviderArrow (
365+ lambda x : self ._cast_outputs_to_dataframe (self .predict_fun (x ))
366+ )
367+ return ppa
368+
369+ def _get_nonarrow_prediction_provider (self ):
370+ if self .dataframe_input :
371+ ppn = PredictionProvider (
372+ lambda x : self ._cast_outputs (
373+ self .predict_fun (prediction_object_to_pandas (x ))
340374 )
341- else :
342- self . prediction_provider = PredictionProvider (
343- lambda x : self . _cast_outputs (
344- predict_fun ( prediction_object_to_numpy ( x ))
345- )
375+ )
376+ else :
377+ ppn = PredictionProvider (
378+ lambda x : self . _cast_outputs (
379+ self . predict_fun ( prediction_object_to_numpy ( x ) )
346380 )
381+ )
382+ return ppn
347383
348384 def _cast_outputs (self , output_array ):
349385 return df_to_prediction_object (
@@ -388,12 +424,8 @@ def predictAsync(self, inputs: List[PredictionInput]) -> CompletableFuture:
388424 :obj:`CompletableFuture`
389425 A Java :obj:`CompletableFuture` containing the model outputs.
390426 """
391- if self .arrow and self .prediction_provider is None :
392- self .prediction_provider = (
393- self .prediction_provider_arrow .get_as_prediction_provider (inputs [0 ])
394- )
395- out = self .prediction_provider .predictAsync (inputs )
396- return out
427+
428+ return self .prediction_provider .predictAsync (inputs )
397429
398430 def __call__ (self , inputs ):
399431 """
@@ -405,6 +437,51 @@ def __call__(self, inputs):
405437 """
406438 return self .predict_fun (inputs )
407439
440+ class ArrowTransmission :
441+ """
442+ Context class to ensure all predictAsync calls within the context use arrow.
443+
444+ Parameters
445+ ----------
446+ model: The TrustyAI :obj:`Model` or PredictionProvider
447+ paradigm_input: A single :obj:`PredictionInput` by which to establish the arrow schema.
448+ All subsequent :obj:`PredictionInput`s communicated must have this schema.
449+ """
450+
451+ def __init__ (self , model , paradigm_input : OneInputUnionType ):
452+ self .model = model
453+ self .model_is_python = isinstance (model , Model )
454+ self .paradigm_input = one_input_convert (paradigm_input )
455+ self .previous_model_state = None
456+
457+ def __enter__ (self ):
458+ if self .model_is_python :
459+ self .previous_model_state = self .model .prediction_provider
460+ self .model ._set_arrow (self .paradigm_input )
461+
462+ def __exit__ (self , exit_type , value , traceback ):
463+ if self .model_is_python :
464+ self .model .prediction_provider = self .previous_model_state
465+
466+ class NonArrowTransmission :
467+ """
468+ Context class to ensure all predictAsync calls within the context DO NOT use arrow.
469+ """
470+
471+ def __init__ (self , model ):
472+ self .model = model
473+ self .model_is_python = isinstance (model , Model )
474+ self .previous_model_state = None
475+
476+ def __enter__ (self ):
477+ if self .model_is_python :
478+ self .previous_model_state = self .model .prediction_provider
479+ self .model ._set_nonarrow ()
480+
481+ def __exit__ (self , exit_type , value , traceback ):
482+ if self .model_is_python :
483+ self .model .prediction_provider = self .previous_model_state
484+
408485
409486@_jcustomizer .JImplementationFor ("org.kie.trustyai.explainability.model.Output" )
410487# pylint: disable=no-member
0 commit comments