Skip to content

Commit 2b74f34

Browse files
committed
[WIP] Initial implementation for AADCrossEntropyEvaluator
1 parent b613b66 commit 2b74f34

File tree

1 file changed

+80
-0
lines changed

1 file changed

+80
-0
lines changed

src/coniferest/aadforest.py

Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44

55
import numpy as np
66
from scipy.optimize import minimize
7+
from scipy.special import log_expit, expit
78

89
from .calc_trees import calc_paths_sum, calc_paths_sum_transpose # noqa
910
from .coniferest import Coniferest, ConiferestEvaluator
@@ -31,6 +32,85 @@ def fit_known(self, data, known_data, known_labels):
3132
raise NotImplementedError()
3233

3334

35+
class AADCrossEntropyEvaluator(AADEvaluator):
    """AAD evaluator with a logistic (cross-entropy) loss.

    Learns one weight per tree leaf plus a scalar bias.  The sample
    score is the sigmoid of the weighted leaf-depth sum plus the bias,
    i.e. an estimate of the probability of a sample being REGULAR
    (non-anomalous) data.  In the optimization methods the parameter
    vector packs the bias first: ``weights[0]`` is the bias and
    ``weights[1:]`` are the per-leaf weights.
    """

    def __init__(self, aad):
        super().__init__(aad)
        # One weight per leaf, uniform to start.
        self.weights = np.ones(shape=(self.n_leaves,))
        # NOTE(review): 0.0 starting bias is a guess ("Not sure about
        # 0.0" in the original) — confirm a better prior if one exists.
        self.bias = 0.0

    def _decision_values(self, weights, data):
        """Return raw decision values ``phi(data) . weights[1:] + weights[0]``.

        Shared by :meth:`loss` and :meth:`loss_gradient`, which use the
        packed parameter vector (bias first, leaf weights after).
        """
        return calc_paths_sum(
            self.selectors,
            self.node_offsets,
            data,
            weights[1:],
            num_threads=self.num_threads,
            batch_size=self.get_batch_size(self.n_trees),
        ) + weights[0]

    def score_samples(self, x, weights=None):
        """Score samples as the probability of being REGULAR data.

        Parameters
        ----------
        x : np.ndarray
            Samples to score; converted to C-contiguous layout if needed.
        weights : np.ndarray or None
            Per-leaf weights to use; defaults to ``self.weights``.

        Returns
        -------
        np.ndarray
            Sigmoid-squashed scores, one per sample.
        """
        # calc_paths_sum expects C-contiguous input.
        if not x.flags["C_CONTIGUOUS"]:
            x = np.ascontiguousarray(x)

        if weights is None:
            weights = self.weights

        return expit(calc_paths_sum(
            self.selectors,
            self.node_offsets,
            x,
            weights,
            num_threads=self.num_threads,
            batch_size=self.get_batch_size(self.n_trees),
        ) + self.bias)

    def loss(
            self,
            weights,
            known_data,
            known_labels):
        """Logistic loss over the labeled samples.

        ``known_labels`` are presumably in {-1, +1} — the loss is the
        standard ``-sum(log sigmoid(y * v))`` form, computed stably via
        scipy's ``log_expit``.
        """
        v = self._decision_values(weights, known_data)

        return -np.sum(log_expit(known_labels * v))

    def loss_gradient(
            self,
            weights,
            known_data,
            known_labels):
        """Gradient of :meth:`loss` w.r.t. the packed ``weights``.

        Returns a vector with the bias derivative first, then the
        per-leaf weight derivatives, matching the packing of ``weights``.
        """
        v = self._decision_values(weights, known_data)

        # d/dv [-log sigmoid(y v)] = -y * sigmoid(-y v)
        dloss_dv = -known_labels * expit(-known_labels * v)
        dloss_dbias = np.sum(dloss_dv)
        dloss_dweights = calc_paths_sum_transpose(
            self.selectors,
            self.node_offsets,
            self.leaf_offsets,
            known_data,
            dloss_dv,
            num_threads=self.num_threads,
            batch_size=self.get_batch_size(len(known_data)),
        )

        return np.concatenate([[dloss_dbias], dloss_dweights])

    def loss_hessian(
            self,
            weights,
            vector,
            known_data,
            known_labels):
        """Hessian-vector product of :meth:`loss` — not implemented yet."""
        # Raise instead of silently returning None (the original `pass`)
        # so an optimizer requesting hessp fails loudly; this matches the
        # NotImplementedError convention used by fit_known above.
        raise NotImplementedError()
112+
113+
34114
class AADHingeEvaluator(AADEvaluator):
35115
def __init__(self, aad):
36116
super(AADHingeEvaluator, self).__init__(aad)

0 commit comments

Comments
 (0)