22# pylint: disable = unused-import, wrong-import-order
33# pylint: disable = consider-using-f-string
44"""General model classes"""
5+ import logging
6+ import traceback
57import uuid as _uuid
68from typing import List , Optional , Union , Callable
79import pandas as pd
@@ -322,7 +324,8 @@ def __init__(self, predict_fun, **kwargs):
322324 transfer between Java and Python. If false, Arrow will be automatically used in
323325 situations where it is advantageous to do so.
324326 """
325- self .predict_fun = predict_fun
327+
328+ self .predict_fun = self ._error_catcher (predict_fun )
326329 self .kwargs = kwargs
327330
328331 self .prediction_provider_arrow = None
@@ -332,6 +335,24 @@ def __init__(self, predict_fun, **kwargs):
332335 # set model to use non-arrow by default, as this requires no dataset information
333336 self ._set_nonarrow ()
334337
338+ def _error_catcher (self , predict_fun ):
339+ """Wrapper for predict function to capture errors to Python logger before the JVM dies"""
340+
341+ def wrapper (x ):
342+ try :
343+ return predict_fun (x )
344+ except Exception as e :
345+ logging .error (
346+ " Fatal runtime error within the `predict_fun` supplied to trustyai.Model"
347+ )
348+ logging .error (
349+ " The error message has been captured and reproduced below:"
350+ )
351+ logging .error (" %s" , traceback .format_exc ())
352+ raise e
353+
354+ return wrapper
355+
335356 @property
336357 def dataframe_input (self ):
337358 """Get dataframe_input kwarg value"""
@@ -483,7 +504,7 @@ def __enter__(self):
483504 self .previous_model_state = self .model .prediction_provider
484505 self .model ._set_arrow (self .paradigm_input )
485506
486- def __exit__ (self , exit_type , value , traceback ):
507+ def __exit__ (self , exit_type , value , tb ):
487508 if self .model_is_python :
488509 self .model .prediction_provider = self .previous_model_state
489510
@@ -502,7 +523,7 @@ def __enter__(self):
502523 self .previous_model_state = self .model .prediction_provider
503524 self .model ._set_nonarrow ()
504525
505- def __exit__ (self , exit_type , value , traceback ):
526+ def __exit__ (self , exit_type , value , tb ):
506527 if self .model_is_python :
507528 self .model .prediction_provider = self .previous_model_state
508529
0 commit comments