Skip to content

Commit 2ae0b83

Browse files
committed
Move _q_tau from AADForest to AADEvaluator
q_tau is Hinge-loss specific value
1 parent 861881b commit 2ae0b83

File tree

2 files changed

+14
-13
lines changed

2 files changed

+14
-13
lines changed

src/coniferest/aadforest.py

Lines changed: 13 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,20 @@
1414
class AADEvaluator(ConiferestEvaluator):
1515
def __init__(self, aad):
1616
super(AADEvaluator, self).__init__(aad, map_value=aad.map_value)
17+
self.budget = aad.budget
1718
self.weights = np.full(shape=(self.n_leaves,), fill_value=np.reciprocal(np.sqrt(self.n_leaves)))
1819

20+
def _q_tau(self, scores):
21+
if isinstance(self.budget, int):
22+
if self.budget >= len(scores):
23+
return np.max(scores)
24+
25+
return np.partition(scores, self.budget)[self.budget]
26+
elif isinstance(self.budget, float):
27+
return np.quantile(scores, self.budget)
28+
29+
raise ValueError("self.budget must be an int or float")
30+
1931
def score_samples(self, x, weights=None):
2032
"""
2133
Perform the computations.
@@ -250,17 +262,6 @@ def _build_trees(self, data):
250262
self.trees = self.build_trees(data, self.n_trees)
251263
self.evaluator = AADEvaluator(self)
252264

253-
def _q_tau(self, scores):
254-
if isinstance(self.budget, int):
255-
if self.budget >= len(scores):
256-
return np.max(scores)
257-
258-
return np.partition(scores, self.budget)[self.budget]
259-
elif isinstance(self.budget, float):
260-
return np.quantile(scores, self.budget)
261-
262-
raise ValueError("self.budget must be an int or float")
263-
264265
def fit(self, data, labels=None):
265266
"""
266267
Build the trees with the data `data`.
@@ -324,7 +325,7 @@ def fit_known(self, data, known_data=None, known_labels=None):
324325
return self
325326

326327
scores = self.score_samples(data)
327-
q_tau = self._q_tau(scores)
328+
q_tau = self.evaluator._q_tau(scores)
328329

329330
anomaly_count = np.count_nonzero(known_labels == Label.ANOMALY)
330331
nominal_count = np.count_nonzero(known_labels == Label.REGULAR)

tests/test_aadforest.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -113,7 +113,7 @@ def test_benchmark_loss_gradient(n_samples, n_trees, n_jobs, benchmark):
113113
anomaly_count = np.count_nonzero(known_labels == Label.ANOMALY)
114114
nominal_count = np.count_nonzero(known_labels == Label.REGULAR)
115115
scores = forest.score_samples(data)
116-
q_tau = forest._q_tau(scores)
116+
q_tau = forest.evaluator._q_tau(scores)
117117

118118
benchmark(
119119
forest.evaluator.loss_gradient,

0 commit comments

Comments
 (0)