@@ -24,9 +24,8 @@ def _get_states(
24
24
labels : torch .Tensor ,
25
25
predictions : torch .Tensor ,
26
26
weights : torch .Tensor ,
27
- num_candidates : torch .Tensor ,
28
27
) -> Dict [str , torch .Tensor ]:
29
- gauc_res = compute_gauc_3d (predictions , labels , num_candidates )
28
+ gauc_res = compute_gauc_3d (predictions , labels , weights )
30
29
return {
31
30
"auc_sum" : gauc_res ["auc_sum" ],
32
31
"num_samples" : gauc_res ["num_samples" ],
@@ -44,8 +43,8 @@ class GAUCMetricValueTest(unittest.TestCase):
44
43
def setUp (self ) -> None :
45
44
self .predictions = {"DefaultTask" : None }
46
45
self .labels = {"DefaultTask" : None }
46
+ self .weights = {"DefaultTask" : None }
47
47
self .num_candidates = None
48
- self .weights = None
49
48
self .batches = {
50
49
"predictions" : self .predictions ,
51
50
"labels" : self .labels ,
@@ -62,13 +61,13 @@ def setUp(self) -> None:
62
61
def test_calc_gauc_simple (self ) -> None :
63
62
self .predictions ["DefaultTask" ] = torch .tensor ([[0.9 , 0.8 , 0.7 , 0.6 , 0.5 ]])
64
63
self .labels ["DefaultTask" ] = torch .tensor ([[1 , 0 , 1 , 1 , 0 ]])
64
+ self .weights ["DefaultTask" ] = torch .tensor ([[1 , 1 , 1 , 1 , 1 ]])
65
65
self .num_candidates = torch .tensor ([3 , 2 ])
66
- self .weights = None
67
66
self .batches = {
68
67
"predictions" : self .predictions ,
69
68
"labels" : self .labels ,
70
69
"num_candidates" : self .num_candidates ,
71
- "weights" : None ,
70
+ "weights" : self . weights ,
72
71
}
73
72
74
73
expected_gauc = torch .tensor ([0.75 ], dtype = torch .double )
@@ -97,13 +96,13 @@ def test_calc_gauc_hard(self) -> None:
97
96
[[0.3 , 0.9 , 0.1 , 0.8 , 0.2 , 0.8 , 0.7 , 0.6 , 0.5 , 0.5 ]]
98
97
)
99
98
self .labels ["DefaultTask" ] = torch .tensor ([[1 , 1 , 1 , 0 , 0 , 1 , 0 , 1 , 1 , 0 ]])
99
+ self .weights ["DefaultTask" ] = torch .tensor ([[1 , 1 , 1 , 1 , 1 , 1 , 1 , 1 , 1 , 1 ]])
100
100
self .num_candidates = torch .tensor ([2 , 3 , 3 , 2 ])
101
- self .weights = None
102
101
self .batches = {
103
102
"predictions" : self .predictions ,
104
103
"labels" : self .labels ,
105
104
"num_candidates" : self .num_candidates ,
106
- "weights" : None ,
105
+ "weights" : self . weights ,
107
106
}
108
107
109
108
expected_gauc = torch .tensor ([0.25 ], dtype = torch .double )
@@ -130,8 +129,8 @@ def test_calc_gauc_hard(self) -> None:
130
129
def test_calc_gauc_all_0_labels (self ) -> None :
131
130
self .predictions ["DefaultTask" ] = torch .tensor ([[0.9 , 0.8 , 0.7 , 0.6 , 0.5 ]])
132
131
self .labels ["DefaultTask" ] = torch .tensor ([[0 , 0 , 0 , 0 , 0 ]])
132
+ self .weights ["DefaultTask" ] = torch .tensor ([[1 , 1 , 1 , 1 , 1 ]])
133
133
self .num_candidates = torch .tensor ([3 , 2 ])
134
- self .weights = None
135
134
self .batches = {
136
135
"predictions" : self .predictions ,
137
136
"labels" : self .labels ,
@@ -163,8 +162,8 @@ def test_calc_gauc_all_0_labels(self) -> None:
163
162
def test_calc_gauc_all_1_labels (self ) -> None :
164
163
self .predictions ["DefaultTask" ] = torch .tensor ([[0.9 , 0.8 , 0.7 , 0.6 , 0.5 ]])
165
164
self .labels ["DefaultTask" ] = torch .tensor ([[1 , 1 , 1 , 1 , 1 ]])
165
+ self .weights ["DefaultTask" ] = torch .tensor ([[1 , 1 , 1 , 1 , 1 ]])
166
166
self .num_candidates = torch .tensor ([3 , 2 ])
167
- self .weights = None
168
167
self .batches = {
169
168
"predictions" : self .predictions ,
170
169
"labels" : self .labels ,
@@ -196,6 +195,7 @@ def test_calc_gauc_all_1_labels(self) -> None:
196
195
def test_calc_gauc_identical_predictions (self ) -> None :
197
196
self .predictions ["DefaultTask" ] = torch .tensor ([[0.8 , 0.8 , 0.8 , 0.8 , 0.8 ]])
198
197
self .labels ["DefaultTask" ] = torch .tensor ([[1 , 1 , 0 , 1 , 0 ]])
198
+ self .weights ["DefaultTask" ] = torch .tensor ([[1 , 1 , 1 , 1 , 1 ]])
199
199
self .num_candidates = torch .tensor ([3 , 2 ])
200
200
self .weights = None
201
201
self .batches = {
@@ -225,3 +225,38 @@ def test_calc_gauc_identical_predictions(self) -> None:
225
225
actual_gauc , expected_gauc
226
226
)
227
227
)
228
+
229
+ def test_calc_gauc_weighted (self ) -> None :
230
+ self .predictions ["DefaultTask" ] = torch .tensor (
231
+ [[0.3 , 0.9 , 0.1 , 0.8 , 0.2 , 0.8 , 0.7 , 0.6 , 0.5 , 0.5 ]]
232
+ )
233
+ self .labels ["DefaultTask" ] = torch .tensor ([[1 , 1 , 1 , 0 , 0 , 1 , 0 , 1 , 1 , 0 ]])
234
+ self .weights ["DefaultTask" ] = torch .tensor ([[1 , 1 , 1 , 0 , 1 , 1 , 1 , 0 , 1 , 1 ]])
235
+ self .num_candidates = torch .tensor ([2 , 3 , 3 , 2 ])
236
+ self .batches = {
237
+ "predictions" : self .predictions ,
238
+ "labels" : self .labels ,
239
+ "num_candidates" : self .num_candidates ,
240
+ "weights" : self .weights ,
241
+ }
242
+
243
+ expected_gauc = torch .tensor ([0.5 ], dtype = torch .double )
244
+ expected_num_samples = torch .tensor ([2 ], dtype = torch .double )
245
+ self .gauc .update (** self .batches )
246
+ gauc_res = self .gauc .compute ()
247
+ actual_gauc , num_effective_samples = (
248
+ gauc_res ["gauc-DefaultTask|window_gauc" ],
249
+ gauc_res ["gauc-DefaultTask|window_gauc_num_samples" ],
250
+ )
251
+ if not torch .allclose (expected_num_samples , num_effective_samples ):
252
+ raise ValueError (
253
+ "actual num sample {} is not equal to expected num sample {}" .format (
254
+ num_effective_samples , expected_num_samples
255
+ )
256
+ )
257
+ if not torch .allclose (expected_gauc , actual_gauc ):
258
+ raise ValueError (
259
+ "actual auc {} is not equal to expected auc {}" .format (
260
+ actual_gauc , expected_gauc
261
+ )
262
+ )
0 commit comments