Description
Is your feature request related to a problem? Please describe.
In my prediction use case, I wanted different estimators and estimator configurations at each level of the hierarchy. The problem was harder on the "higher" levels and therefore needed more model capacity there, while capacity could be reduced on the lower levels.
Describe the solution you'd like
I would like to parameterize the hierarchical model by passing in a factory function instead of a fixed estimator that is copied for every node. My current hack to do this looks like this:
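```python
from copy import deepcopy
from typing import Callable, Optional, Union

import networkx as nx
from sklearn.base import BaseEstimator
from sklearn.ensemble import RandomForestClassifier
from hiclass import LocalClassifierPerParentNode

# A factory receives the depth of a parent node and returns a fresh estimator.
ClassifierFactory = Callable[[int], BaseEstimator]


def estimator_based_on_depth(depth: int) -> BaseEstimator:
    # More capacity on the higher (harder) levels, less further down.
    if depth == 0:
        return RandomForestClassifier(n_estimators=100)
    elif depth == 1:
        return RandomForestClassifier(n_estimators=50)
    else:
        return RandomForestClassifier(n_estimators=10)


class CustomizableLocalClassifier(LocalClassifierPerParentNode):
    def __init__(
        self,
        local_classifier: Optional[Union[BaseEstimator, ClassifierFactory]] = None,
        verbose=0,
        edge_list=None,
        replace_classifiers=True,
        n_jobs=1,
        calibration_method=None,
        return_all_probabilities=False,
        probability_combiner="multiply",
        tmp_dir=None,
    ):
        # Keyword arguments keep this robust to parameter reordering upstream.
        super().__init__(
            local_classifier=local_classifier,
            verbose=verbose,
            edge_list=edge_list,
            replace_classifiers=replace_classifiers,
            n_jobs=n_jobs,
            calibration_method=calibration_method,
            return_all_probabilities=return_all_probabilities,
            probability_combiner=probability_combiner,
            tmp_dir=tmp_dir,
        )

    def _initialize_local_classifiers(self):
        # Let the parent class set up self.local_classifier_ and its defaults first.
        super()._initialize_local_classifiers()
        local_classifiers = {}
        nodes = self._get_parents()
        for node in nodes:
            if callable(self.local_classifier):
                # Depth = number of separators in the node name.
                depth = len(node.split(self.separator_)) - 1
                local_classifiers[node] = {"classifier": self.local_classifier(depth)}
            else:
                local_classifiers[node] = {"classifier": deepcopy(self.local_classifier_)}
        # Overwrite the classifiers that the parent class attached to each node.
        nx.set_node_attributes(self.hierarchy_, local_classifiers)
```

Describe alternatives you've considered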
I checked whether there is a simpler, more modular way to achieve this, but could not find one.
Let me know if this feature is of interest; if so, I would open a PR for it.
This could, of course, be extended so that the factory receives the complete path of the node instead of only its depth.
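A rough, stdlib-only sketch of that path-based variant is shown here. It mirrors the loop in `_initialize_local_classifiers`, but `build_local_classifiers`, `StubEstimator`, `OVERRIDES` and the `"::"` separator are hypothetical stand-ins for illustration (hiclass uses its own `separator_`, and real code would return scikit-learn estimators). It is not an existing hiclass API.

```python
from copy import deepcopy
from dataclasses import dataclass
from typing import Any, Callable, Dict, Iterable, Tuple

# Hypothetical separator for this sketch; hiclass uses its own internal `separator_`.
SEPARATOR = "::"


@dataclass
class StubEstimator:
    """Hypothetical stand-in for a scikit-learn estimator."""
    n_estimators: int = 100


# A path factory receives the full node path, e.g. ("Animal", "Mammal").
PathFactory = Callable[[Tuple[str, ...]], Any]

# Hypothetical per-branch override: the "Mammal" subtree is assumed to be harder.
OVERRIDES = {("Animal", "Mammal"): 80}


def estimator_based_on_path(path: Tuple[str, ...]) -> StubEstimator:
    if path in OVERRIDES:
        return StubEstimator(n_estimators=OVERRIDES[path])
    # Otherwise fall back to a depth-based rule, like the depth-only factory.
    depth = len(path) - 1
    return StubEstimator(n_estimators={0: 100, 1: 50}.get(depth, 10))


def build_local_classifiers(
    parent_nodes: Iterable[str],
    local_classifier: Any,  # a fixed estimator or a PathFactory
    separator: str = SEPARATOR,
) -> Dict[str, Dict[str, Any]]:
    """Return {node: {"classifier": estimator}} for every parent node."""
    classifiers = {}
    for node in parent_nodes:
        if callable(local_classifier):
            # Factory: build a fresh estimator from the node's full path.
            path = tuple(node.split(separator))
            classifiers[node] = {"classifier": local_classifier(path)}
        else:
            # Fixed estimator: every node gets its own independent copy.
            classifiers[node] = {"classifier": deepcopy(local_classifier)}
    return classifiers


if __name__ == "__main__":
    parents = ["Animal", "Animal::Mammal", "Animal::Bird", "Animal::Mammal::Dog"]
    built = build_local_classifiers(parents, estimator_based_on_path)
    print({node: c["classifier"].n_estimators for node, c in built.items()})
    # {'Animal': 100, 'Animal::Mammal': 80, 'Animal::Bird': 50, 'Animal::Mammal::Dog': 10}
```

Passing the path tuple rather than a precomputed depth lets the factory derive the depth itself (`len(path) - 1`) while still singling out specific branches. One caveat of dispatching on `callable()`: an uninstantiated estimator class such as `RandomForestClassifier` is also callable, so it would be treated as a factory.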