@@ -453,14 +453,12 @@ def __init__(
453453
454454 # The estimator is checked against the class attribute for conformance.
455455 # This should only trigger if the user uses this class directly.
456- if (
457- self .estimator .__class__ == DecisionTreeClassifier
458- and self ._onedal_factory != onedal_RandomForestClassifier
456+ if self .estimator .__class__ == DecisionTreeClassifier and not issubclass (
457+ self ._onedal_factory , onedal_RandomForestClassifier
459458 ):
460459 self ._onedal_factory = onedal_RandomForestClassifier
461- elif (
462- self .estimator .__class__ == ExtraTreeClassifier
463- and self ._onedal_factory != onedal_ExtraTreesClassifier
460+ elif self .estimator .__class__ == ExtraTreeClassifier and not issubclass (
461+ self ._onedal_factory , onedal_ExtraTreesClassifier
464462 ):
465463 self ._onedal_factory = onedal_ExtraTreesClassifier
466464
@@ -843,14 +841,12 @@ def __init__(
843841
844842 # The splitter is checked against the class attribute for conformance
845843 # This should only trigger if the user uses this class directly.
846- if (
847- self .estimator .__class__ == DecisionTreeRegressor
848- and self ._onedal_factory != onedal_RandomForestRegressor
844+ if self .estimator .__class__ == DecisionTreeRegressor and not issubclass (
845+ self ._onedal_factory , onedal_RandomForestRegressor
849846 ):
850847 self ._onedal_factory = onedal_RandomForestRegressor
851- elif (
852- self .estimator .__class__ == ExtraTreeRegressor
853- and self ._onedal_factory != onedal_ExtraTreesRegressor
848+ elif self .estimator .__class__ == ExtraTreeRegressor and not issubclass (
849+ self ._onedal_factory , onedal_ExtraTreesRegressor
854850 ):
855851 self ._onedal_factory = onedal_ExtraTreesRegressor
856852
0 commit comments