
Commit b1b7e01

xinyuanzzz authored and facebook-github-bot committed
Fix GAUC not calculated with weights (#2895)
Summary:
Pull Request resolved: #2895

The gAUC score is lower than expected (e.g. https://fburl.com/mlhub/vljz497c). In IG, if label presence is false, the corresponding weight is set to 0; such samples should not be considered when calculating gAUC.

Reviewed By: yunjiangster

Differential Revision: D73231152

fbshipit-source-id: 3a83269948db27341cd8b6ad5d5f7b553195aa75
1 parent 6dc7c16 commit b1b7e01
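
For context on the fix: a sample whose weight is 0 (in IG, one whose label presence is false) should contribute to neither the numerator nor the denominator of its group's AUC. Below is a minimal pairwise sketch of a weighted per-group AUC, separate from the vectorized torchrec code in the diff that follows; the helper name and the O(n²) loop are illustrative only, and tied predictions are ignored here because the metric masks identical-prediction sessions.

import torch

def weighted_group_auc(predictions, labels, weights):
    """Pairwise weighted AUC for one group: each (positive, negative) pair
    counts with weight w_pos * w_neg, so a zero-weight sample drops out of
    both the numerator and the denominator."""
    num = den = 0.0
    for i in range(len(predictions)):
        for j in range(len(predictions)):
            if labels[i] == 1 and labels[j] == 0:
                pair_w = float(weights[i] * weights[j])
                den += pair_w
                if predictions[i] > predictions[j]:
                    num += pair_w
    return num / den if den > 0 else float("nan")

# A mis-ranked negative with weight 0 no longer drags the group's AUC down:
preds = torch.tensor([0.7, 0.9, 0.2])
labels = torch.tensor([1, 0, 0])
print(weighted_group_auc(preds, labels, torch.tensor([1.0, 1.0, 1.0])))  # 0.5
print(weighted_group_auc(preds, labels, torch.tensor([1.0, 0.0, 1.0])))  # 1.0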

File tree: 2 files changed (+70, -35 lines)

torchrec/metrics/gauc.py

Lines changed: 26 additions & 26 deletions
@@ -24,7 +24,7 @@
 def compute_gauc_3d(
     predictions: torch.Tensor,
     labels: torch.Tensor,
-    num_candidates: torch.Tensor,
+    weights: torch.Tensor,
 ) -> Dict[str, torch.Tensor]:
     """Both predictions and labels are 3-d tensors in shape [n_task, n_group, n_sample]."""

@@ -34,7 +34,7 @@ def compute_gauc_3d(
     pre_arange = torch.arange(max_len, device=predictions.device)

     with record_function("## gauc_argsort ##"):
-        sorted_indices = torch.argsort(predictions, descending=True, dim=-1)
+        sorted_indices = torch.argsort(predictions, dim=-1)
         task_indices = (
             pre_arange[:n_task][:, None, None]
             .expand(n_task, n_group, n_sample)
@@ -51,28 +51,26 @@ def compute_gauc_3d(
         sorted_labels = labels[task_indices, group_indices, sample_indices].view(
             n_task, n_group, n_sample
         )
+        sorted_weights = weights[task_indices, group_indices, sample_indices].view(
+            n_task, n_group, n_sample
+        )

     with record_function("## gauc_calculation ##"):
-        num_sample = num_candidates[None, :].expand(n_task, n_group)
-        # Count number of padding zeros.
-        num_zeros = (n_sample - num_candidates)[None, :, None].expand(
-            n_task, n_group, n_sample
-        )  # [n_task, n_group, n_sample]
-        # This assumes the labels are binary.
-        num_zeros = (sorted_labels != 0) * num_zeros
-        rank = torch.flip(pre_arange[:n_sample] + 1, [0])[None, None, :].expand(
-            n_task, n_group, n_sample
-        )
-        positive_rank = sorted_labels * rank - num_zeros  # [n_task, n_group, n_sample]
-        num_positive = sorted_labels.sum(-1)  # [n_task, n_group]
+        pos_mask = sorted_labels
+        neg_mask = 1 - sorted_labels

-        # AUC is calcuated as (sum{positive_ranks} - num{positive_pairs}) /
-        # (num_positive * num_negative).
-        numerator = torch.sum(positive_rank, -1) - (
-            num_positive * (num_positive + 1) / 2
-        )
-        denominator = num_positive * (num_sample - num_positive)
-        auc = numerator / (denominator + 1e-10)  # [n_task, n_group]
+        # cumulative negative *weight* that appear **before** each position
+        cum_neg_weight = torch.cumsum(sorted_weights * neg_mask, dim=-1)
+
+        # contribution of every positive example: w_pos * (sum w_neg ranked lower)
+        contrib = pos_mask * sorted_weights * cum_neg_weight
+        numerator = contrib.sum(-1)  # [n_task, n_group]
+
+        w_pos = (pos_mask * sorted_weights).sum(-1)  # [n_task, n_group]
+        w_neg = (neg_mask * sorted_weights).sum(-1)  # [n_task, n_group]
+        denominator = w_pos * w_neg
+
+        auc = numerator / (denominator + 1e-10)

         # Skip identical prediction sessions.
         identical_prediction_mask = ~(
@@ -85,7 +83,7 @@ def compute_gauc_3d(
             )
         )
         # Skip identical label(all 0s/1s) sessions.
-        identical_label_mask = (num_positive >= 1) * (num_positive < num_sample)
+        identical_label_mask = (w_pos > 0) & (w_neg > 0)
         auc_mask = identical_label_mask * identical_prediction_mask
         auc *= auc_mask
         num_effective_samples = auc_mask.sum(-1)  # [n_task]
@@ -104,23 +102,25 @@ def to_3d(
 def get_auc_states(
     labels: torch.Tensor,
     predictions: torch.Tensor,
-    weights: Optional[torch.Tensor],
+    weights: torch.Tensor,
     num_candidates: torch.Tensor,
 ) -> Dict[str, torch.Tensor]:

     # predictions, labels: [n_task, n_sample]
     max_length = int(num_candidates.max().item())
     predictions_perm = predictions.permute(1, 0)
     labels_perm = labels.permute(1, 0)
+    weights_perm = weights.permute(1, 0)
     predictions_3d = to_3d(predictions_perm, num_candidates, max_length).permute(
         2, 0, 1
     )
     labels_3d = to_3d(labels_perm, num_candidates, max_length).permute(2, 0, 1)
+    weights_3d = to_3d(weights_perm, num_candidates, max_length).permute(2, 0, 1)

     return compute_gauc_3d(
         predictions_3d,
         labels_3d,
-        num_candidates,
+        weights_3d,
     )


@@ -175,9 +175,9 @@ def update(
         num_candidates: torch.Tensor,
         **kwargs: Dict[str, Any],
     ) -> None:
-        if predictions is None or labels is None:
+        if predictions is None or weights is None:
             raise RecMetricException(
-                "Inputs 'predictions' and 'labels' should not be None for GAUCMetricComputation update"
+                "Inputs 'predictions' and 'weights' should not be None for GAUCMetricComputation update"
            )

         states = get_auc_states(labels, predictions, weights, num_candidates)
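
One way to read the new gauc_calculation block above: predictions are argsorted in ascending order, so for each positive the running sum of negative weights seen so far is exactly the weight of the negatives it outranks, and numerator / (w_pos * w_neg) is the weighted pairwise AUC. This is also why num_candidates is no longer needed inside compute_gauc_3d: assuming to_3d zero-pads (as the removed padding-correction code suggests), padded slots carry weight 0 and drop out on their own. A standalone sanity check on one group, mirroring (not importing) the code above:

import torch

# One group from test_calc_gauc_simple, padded to the batch max length of 3.
preds = torch.tensor([0.6, 0.5, 0.0])      # third slot is padding
labels = torch.tensor([1.0, 0.0, 0.0])
weights = torch.tensor([1.0, 1.0, 0.0])    # padding carries weight 0

order = torch.argsort(preds)               # ascending, as in the new code
s_labels, s_weights = labels[order], weights[order]
pos_mask, neg_mask = s_labels, 1 - s_labels

cum_neg_weight = torch.cumsum(s_weights * neg_mask, dim=-1)
numerator = (pos_mask * s_weights * cum_neg_weight).sum()
w_pos = (pos_mask * s_weights).sum()
w_neg = (neg_mask * s_weights).sum()
print(float(numerator / (w_pos * w_neg + 1e-10)))  # ~1.0: the positive outranks the only weighted negative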

torchrec/metrics/tests/test_gauc.py

Lines changed: 44 additions & 9 deletions
@@ -24,9 +24,8 @@ def _get_states(
     labels: torch.Tensor,
     predictions: torch.Tensor,
     weights: torch.Tensor,
-    num_candidates: torch.Tensor,
 ) -> Dict[str, torch.Tensor]:
-    gauc_res = compute_gauc_3d(predictions, labels, num_candidates)
+    gauc_res = compute_gauc_3d(predictions, labels, weights)
     return {
         "auc_sum": gauc_res["auc_sum"],
         "num_samples": gauc_res["num_samples"],
@@ -44,8 +43,8 @@ class GAUCMetricValueTest(unittest.TestCase):
     def setUp(self) -> None:
         self.predictions = {"DefaultTask": None}
         self.labels = {"DefaultTask": None}
+        self.weights = {"DefaultTask": None}
         self.num_candidates = None
-        self.weights = None
         self.batches = {
             "predictions": self.predictions,
             "labels": self.labels,
@@ -62,13 +61,13 @@ def setUp(self) -> None:
     def test_calc_gauc_simple(self) -> None:
         self.predictions["DefaultTask"] = torch.tensor([[0.9, 0.8, 0.7, 0.6, 0.5]])
         self.labels["DefaultTask"] = torch.tensor([[1, 0, 1, 1, 0]])
+        self.weights["DefaultTask"] = torch.tensor([[1, 1, 1, 1, 1]])
         self.num_candidates = torch.tensor([3, 2])
-        self.weights = None
         self.batches = {
             "predictions": self.predictions,
             "labels": self.labels,
             "num_candidates": self.num_candidates,
-            "weights": None,
+            "weights": self.weights,
         }

         expected_gauc = torch.tensor([0.75], dtype=torch.double)
@@ -97,13 +96,13 @@ def test_calc_gauc_hard(self) -> None:
             [[0.3, 0.9, 0.1, 0.8, 0.2, 0.8, 0.7, 0.6, 0.5, 0.5]]
         )
         self.labels["DefaultTask"] = torch.tensor([[1, 1, 1, 0, 0, 1, 0, 1, 1, 0]])
+        self.weights["DefaultTask"] = torch.tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1]])
         self.num_candidates = torch.tensor([2, 3, 3, 2])
-        self.weights = None
         self.batches = {
             "predictions": self.predictions,
             "labels": self.labels,
             "num_candidates": self.num_candidates,
-            "weights": None,
+            "weights": self.weights,
         }

         expected_gauc = torch.tensor([0.25], dtype=torch.double)
@@ -130,8 +129,8 @@ def test_calc_gauc_hard(self) -> None:
     def test_calc_gauc_all_0_labels(self) -> None:
         self.predictions["DefaultTask"] = torch.tensor([[0.9, 0.8, 0.7, 0.6, 0.5]])
         self.labels["DefaultTask"] = torch.tensor([[0, 0, 0, 0, 0]])
+        self.weights["DefaultTask"] = torch.tensor([[1, 1, 1, 1, 1]])
         self.num_candidates = torch.tensor([3, 2])
-        self.weights = None
         self.batches = {
             "predictions": self.predictions,
             "labels": self.labels,
@@ -163,8 +162,8 @@ def test_calc_gauc_all_0_labels(self) -> None:
     def test_calc_gauc_all_1_labels(self) -> None:
         self.predictions["DefaultTask"] = torch.tensor([[0.9, 0.8, 0.7, 0.6, 0.5]])
         self.labels["DefaultTask"] = torch.tensor([[1, 1, 1, 1, 1]])
+        self.weights["DefaultTask"] = torch.tensor([[1, 1, 1, 1, 1]])
         self.num_candidates = torch.tensor([3, 2])
-        self.weights = None
         self.batches = {
             "predictions": self.predictions,
             "labels": self.labels,
@@ -196,6 +195,7 @@ def test_calc_gauc_all_1_labels(self) -> None:
     def test_calc_gauc_identical_predictions(self) -> None:
         self.predictions["DefaultTask"] = torch.tensor([[0.8, 0.8, 0.8, 0.8, 0.8]])
         self.labels["DefaultTask"] = torch.tensor([[1, 1, 0, 1, 0]])
+        self.weights["DefaultTask"] = torch.tensor([[1, 1, 1, 1, 1]])
         self.num_candidates = torch.tensor([3, 2])
         self.weights = None
         self.batches = {
@@ -225,3 +225,38 @@ def test_calc_gauc_identical_predictions(self) -> None:
                     actual_gauc, expected_gauc
                 )
             )
+
+    def test_calc_gauc_weighted(self) -> None:
+        self.predictions["DefaultTask"] = torch.tensor(
+            [[0.3, 0.9, 0.1, 0.8, 0.2, 0.8, 0.7, 0.6, 0.5, 0.5]]
+        )
+        self.labels["DefaultTask"] = torch.tensor([[1, 1, 1, 0, 0, 1, 0, 1, 1, 0]])
+        self.weights["DefaultTask"] = torch.tensor([[1, 1, 1, 0, 1, 1, 1, 0, 1, 1]])
+        self.num_candidates = torch.tensor([2, 3, 3, 2])
+        self.batches = {
+            "predictions": self.predictions,
+            "labels": self.labels,
+            "num_candidates": self.num_candidates,
+            "weights": self.weights,
+        }
+
+        expected_gauc = torch.tensor([0.5], dtype=torch.double)
+        expected_num_samples = torch.tensor([2], dtype=torch.double)
+        self.gauc.update(**self.batches)
+        gauc_res = self.gauc.compute()
+        actual_gauc, num_effective_samples = (
+            gauc_res["gauc-DefaultTask|window_gauc"],
+            gauc_res["gauc-DefaultTask|window_gauc_num_samples"],
+        )
+        if not torch.allclose(expected_num_samples, num_effective_samples):
+            raise ValueError(
+                "actual num sample {} is not equal to expected num sample {}".format(
+                    num_effective_samples, expected_num_samples
+                )
+            )
+        if not torch.allclose(expected_gauc, actual_gauc):
+            raise ValueError(
+                "actual auc {} is not equal to expected auc {}".format(
+                    actual_gauc, expected_gauc
+                )
+            )
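
The expected values in test_calc_gauc_weighted can be confirmed by hand: group 1 ([0.3, 0.9]) is all-positive and group 4 ([0.5, 0.5]) has identical predictions, so only groups 2 and 3 count, with weighted AUCs of 0 and 1, giving 2 effective samples and a gAUC of 0.5. A quick brute-force check (plain Python, not the metric code; the identical-prediction skip is simplified here):

# Brute-force recomputation of the weighted test's expected values.
preds = [0.3, 0.9, 0.1, 0.8, 0.2, 0.8, 0.7, 0.6, 0.5, 0.5]
labels = [1, 1, 1, 0, 0, 1, 0, 1, 1, 0]
weights = [1, 1, 1, 0, 1, 1, 1, 0, 1, 1]
num_candidates = [2, 3, 3, 2]

aucs, start = [], 0
for n in num_candidates:
    p, y, w = (x[start : start + n] for x in (preds, labels, weights))
    start += n
    num = sum(
        wi * wj
        for pi, yi, wi in zip(p, y, w) if yi == 1
        for pj, yj, wj in zip(p, y, w) if yj == 0 and pi > pj
    )
    den = sum(wi for yi, wi in zip(y, w) if yi == 1) * sum(
        wj for yj, wj in zip(y, w) if yj == 0
    )
    if den > 0 and len(set(p)) > 1:  # skip all-0/all-1 and identical-prediction groups
        aucs.append(num / den)

print(len(aucs), sum(aucs) / len(aucs))  # 2 effective groups, gAUC = 0.5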
