Skip to content

Commit 861881b

Browse files
committed
[WIP] Introduce loss function argument for AADForest
1 parent b9ba72e commit 861881b

File tree

1 file changed

+12
-0
lines changed

1 file changed

+12
-0
lines changed

src/coniferest/aadforest.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -176,6 +176,10 @@ class AADForest(Coniferest):
176176
map_value : ["const", "exponential", "linear", "reciprocal"] or callable, optional
177177
An function applied to the leaf depth before weighting. Possible
178178
meaning variants are: 1, 1-exp(-x), x, -1/x.
179+
180+
loss : ["hinge"], optional (default="hinge")
181+
Loss function used to optimize the leaf weights. The default is the hinge loss,
182+
as in the original paper.
179183
"""
180184

181185
def __init__(
@@ -190,6 +194,7 @@ def __init__(
190194
random_seed=None,
191195
sampletrees_per_batch=1 << 20,
192196
map_value=None,
197+
loss="hinge",
193198
):
194199
super().__init__(
195200
trees=[],
@@ -231,6 +236,13 @@ def __init__(
231236
else:
232237
raise ValueError(f"map_value is neither a callable nor one of {', '.join(MAP_VALUES.keys())}.")
233238

239+
LOSSES = ["hinge"]
240+
241+
if loss not in LOSSES:
242+
raise ValueError(f"loss is not one of {', '.join(LOSSES)}.")
243+
244+
self.loss = loss
245+
234246
self.evaluator = None
235247

236248
def _build_trees(self, data):

0 commit comments

Comments
 (0)