@@ -295,7 +295,7 @@ class Model:
295295    """ 
296296
297297    def  __init__ (
298-         self , predict_fun , dataframe_input = False , output_names = None , arrow = False 
298+         self , predict_fun , dataframe_input = False , output_names = None , disable_arrow = False 
299299    ):
300300        """ 
301301        Wrap the model as a TrustyAI :obj:`PredictionProvider` Java class. 
@@ -311,39 +311,75 @@ def __init__(
311311        output_names : List[String]: 
312312            If the model outputs a numpy array, you can specify the names of the model outputs 
313313            here. 
314-         arrow : bool 
315-             Whether to use  Apache arrow to speed up  data transfer between Java and Python.  
316-             In general, set this to ``true`` whenever LIME or SHAP explanations are needed,  
317-             and ``false`` for counterfactuals . 
314+         disable_arrow : bool 
315+             If true,  Apache Arrow will not be used to accelerate  data transfer between Java 
316+             and Python. If false, Arrow will be automatically used in situations where it is  
317+             advantageous to do so . 
318318        """ 
319-         self .arrow  =  arrow 
319+         self .disable_arrow  =  disable_arrow 
320320        self .predict_fun  =  predict_fun 
321321        self .output_names  =  output_names 
322+         self .dataframe_input  =  dataframe_input 
322323
323-         if  arrow :
324-             self .prediction_provider  =  None 
325-             if  not  dataframe_input :
326-                 self .prediction_provider_arrow  =  PredictionProviderArrow (
327-                     lambda  x : self ._cast_outputs_to_dataframe (predict_fun (x .values ))
328-                 )
329-             else :
330-                 self .prediction_provider_arrow  =  PredictionProviderArrow (
331-                     lambda  x : self ._cast_outputs_to_dataframe (predict_fun (x ))
324+         self .prediction_provider_arrow  =  None 
325+         self .prediction_provider_normal  =  None 
326+         self .prediction_provider  =  None 
327+ 
328+         # set model to use non-arrow by default, as this requires no dataset information 
329+         self ._set_nonarrow ()
330+ 
331+     def  _set_arrow (self , paradigm_input : PredictionInput ):
332+         """ 
333+         Ready the model for arrow-based prediction communication. 
334+ 
335+         Parameters 
336+         ---------- 
337+         paradigm_input: A single :obj:`PredictionInput` by which to establish the arrow schema. 
338+         All subsequent :obj:`PredictionInput`s communicated must have this schema. 
339+         """ 
340+         if  self .disable_arrow :
341+             self ._set_nonarrow ()
342+         else :
343+             if  self .prediction_provider_arrow  is  None :
344+                 raw_ppa  =  self ._get_arrow_prediction_provider ()
345+                 self .prediction_provider_arrow  =  raw_ppa .get_as_prediction_provider (
346+                     paradigm_input 
332347                )
348+             self .prediction_provider  =  self .prediction_provider_arrow 
349+ 
350+     def  _set_nonarrow (self ):
351+         """ 
352+         Ready the model for non-arrow-prediction communication. 
353+         """ 
354+         if  self .prediction_provider_normal  is  None :
355+             self .prediction_provider_normal  =  self ._get_nonarrow_prediction_provider ()
356+         self .prediction_provider  =  self .prediction_provider_normal 
357+ 
358+     def  _get_arrow_prediction_provider (self ):
359+         if  not  self .dataframe_input :
360+             ppa  =  PredictionProviderArrow (
361+                 lambda  x : self ._cast_outputs_to_dataframe (self .predict_fun (x .values ))
362+             )
333363        else :
334-             self .prediction_provider_arrow  =  None 
335-             if  dataframe_input :
336-                 self .prediction_provider  =  PredictionProvider (
337-                     lambda  x : self ._cast_outputs (
338-                         predict_fun (prediction_object_to_pandas (x ))
339-                     )
364+             ppa  =  PredictionProviderArrow (
365+                 lambda  x : self ._cast_outputs_to_dataframe (self .predict_fun (x ))
366+             )
367+         return  ppa 
368+ 
369+     def  _get_nonarrow_prediction_provider (self ):
370+         if  self .dataframe_input :
371+             ppn  =  PredictionProvider (
372+                 lambda  x : self ._cast_outputs (
373+                     self .predict_fun (prediction_object_to_pandas (x ))
340374                )
341-             else : 
342-                  self . prediction_provider   =   PredictionProvider ( 
343-                      lambda   x :  self . _cast_outputs (
344-                          predict_fun ( prediction_object_to_numpy ( x )) 
345-                     )
375+             ) 
376+         else : 
377+             ppn   =   PredictionProvider (
378+                 lambda   x :  self . _cast_outputs ( 
379+                     self . predict_fun ( prediction_object_to_numpy ( x ) )
346380                )
381+             )
382+         return  ppn 
347383
348384    def  _cast_outputs (self , output_array ):
349385        return  df_to_prediction_object (
@@ -388,12 +424,8 @@ def predictAsync(self, inputs: List[PredictionInput]) -> CompletableFuture:
388424        :obj:`CompletableFuture` 
389425            A Java :obj:`CompletableFuture` containing the model outputs. 
390426        """ 
391-         if  self .arrow  and  self .prediction_provider  is  None :
392-             self .prediction_provider  =  (
393-                 self .prediction_provider_arrow .get_as_prediction_provider (inputs [0 ])
394-             )
395-         out  =  self .prediction_provider .predictAsync (inputs )
396-         return  out 
427+ 
428+         return  self .prediction_provider .predictAsync (inputs )
397429
398430    def  __call__ (self , inputs ):
399431        """ 
@@ -405,6 +437,51 @@ def __call__(self, inputs):
405437        """ 
406438        return  self .predict_fun (inputs )
407439
440+     class  ArrowTransmission :
441+         """ 
442+         Context class to ensure all predictAsync calls within the context use arrow. 
443+ 
444+         Parameters 
445+         ---------- 
446+         model: The TrustyAI :obj:`Model` or PredictionProvider 
447+         paradigm_input: A single :obj:`PredictionInput` by which to establish the arrow schema. 
448+          All subsequent :obj:`PredictionInput`s communicated must have this schema. 
449+         """ 
450+ 
451+         def  __init__ (self , model , paradigm_input : OneInputUnionType ):
452+             self .model  =  model 
453+             self .model_is_python  =  isinstance (model , Model )
454+             self .paradigm_input  =  one_input_convert (paradigm_input )
455+             self .previous_model_state  =  None 
456+ 
457+         def  __enter__ (self ):
458+             if  self .model_is_python :
459+                 self .previous_model_state  =  self .model .prediction_provider 
460+                 self .model ._set_arrow (self .paradigm_input )
461+ 
462+         def  __exit__ (self , exit_type , value , traceback ):
463+             if  self .model_is_python :
464+                 self .model .prediction_provider  =  self .previous_model_state 
465+ 
466+     class  NonArrowTransmission :
467+         """ 
468+         Context class to ensure all predictAsync calls within the context DO NOT use arrow. 
469+         """ 
470+ 
471+         def  __init__ (self , model ):
472+             self .model  =  model 
473+             self .model_is_python  =  isinstance (model , Model )
474+             self .previous_model_state  =  None 
475+ 
476+         def  __enter__ (self ):
477+             if  self .model_is_python :
478+                 self .previous_model_state  =  self .model .prediction_provider 
479+                 self .model ._set_nonarrow ()
480+ 
481+         def  __exit__ (self , exit_type , value , traceback ):
482+             if  self .model_is_python :
483+                 self .model .prediction_provider  =  self .previous_model_state 
484+ 
408485
409486@_jcustomizer .JImplementationFor ("org.kie.trustyai.explainability.model.Output" ) 
410487# pylint: disable=no-member 
0 commit comments