@@ -728,3 +728,62 @@ def Summary(self, name):
728728 def _CreateSummary (self , name ):
729729 """Returns a tf.Summary for this metric."""
730730 raise NotImplementedError ()
731+
732+
733+ class GroupPairAUCMetric (AUCMetric ):
734+ """Compute the AUC score for all pairs extracted from each group of items.
735+
736+ For each group of items, the metric extracts all pairs with different
737+ target values. For each pair (i, j), the metric computes the binary
738+ classification AUC where the `label = 1 if target[i] > target[j] else 0` and
739+ `prob = sigmoid(logits[i] - logits[j])`.
740+
741+ To prevent generating pairs across groups, an additional arg `group_ids` is
742+ required, which is a list of ints that specifies the group_id of each item.
743+
744+ In addition, in order to achieve streaming computation, items from the same
745+ group need to form continuous chunks,
746+ e.g. group_ids = [0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2].
747+
748+ In the case of [0, 0, 1, 1, 1, 0, 0, 2, 2, 2, 2], the second chunk of 0s will
749+ be treated as a separate 3rd group rather than part of the 1st group.
750+ """
751+
752+ def UpdateRaw (self , group_ids , target , logits , weight = None ):
753+ """Updates the metrics.
754+
755+ Args:
756+ group_ids: An array to specify the group identity.
757+ target: An array to specify the groundtruth float values.
758+ logits: An array to specify the raw prediction logits.
759+ weight: An array to specify the sample weight for the auc computation.
760+ """
761+
762+ assert self ._samples <= 0
763+
764+ sigmoid = lambda x : 1.0 / (1.0 + np .exp (- x ))
765+
766+ def _ProcessChunk (s , e ):
767+ for i in range (s , e ):
768+ for j in range (i + 1 , e ):
769+ if target [i ] != target [j ]:
770+ pair_label = 1 if target [i ] > target [j ] else 0
771+ pair_prob = sigmoid (logits [i ] - logits [j ])
772+ self ._label .append (pair_label )
773+ self ._prob .append (pair_prob )
774+ if weight :
775+ self ._weight .append (min (1.0 , weight [i ] + weight [j ]))
776+ else :
777+ self ._weight .append (1.0 )
778+
779+ s , e = 0 , 1
780+ while e <= len (target ):
781+ # Find the end of a chunk
782+ if e == len (target ) or group_ids [e ] != group_ids [s ]:
783+ # Process the current chunk [s:e]
784+ _ProcessChunk (s , e )
785+
786+ # Start a new chunk by setting `s` to `e`
787+ s = e
788+ # Increment `e` by 1.
789+ e += 1
0 commit comments