@@ -10,7 +10,8 @@ def __init__(self, name):
1010 self ._name = name
1111
1212 def _compute (self , y_pred , y_true ):
13- """Helper function for computing the metric. Subclasses should implement this.
13+ """Helper function for computing the metric. Subclasses should
14+ implement this.
1415
1516 Args:
1617 - y_pred (Tensor): Predicted targets or model output
@@ -21,7 +22,8 @@ def _compute(self, y_pred, y_true):
2122 return NotImplementedError
2223
2324 def worst (self , metrics ):
24- """Given a list/numpy array/Tensor of metrics, computes the worst-case metric.
25+ """Given a list/numpy array/Tensor of metrics, computes the worst-case
26+ metric.
2527
2628 Args:
2729 - metrics (Tensor, numpy array, or list): Metrics
@@ -34,33 +36,35 @@ def worst(self, metrics):
3436 def name (self ):
3537 """Metric name.
3638
37- Used to name the key in the results dictionaries returned by the metric.
39+ Used to name the key in the results dictionaries returned by the
40+ metric.
3841 """
3942 return self ._name
4043
4144 @property
4245 def agg_metric_field (self ):
43- """The name of the key in the results dictionary returned by Metric.compute().
46+ """The name of the key in the results dictionary returned by
47+ Metric.compute().
4448
45- This should correspond to the aggregate metric computed on all of y_pred and y_true, in
46- contrast to a group-wise evaluation.
49+ This should correspond to the aggregate metric computed on all
50+ of y_pred and y_true, in contrast to a group-wise evaluation.
4751 """
4852 return f'{ self .name } _all'
4953
5054 def group_metric_field (self , group_idx ):
51- """The name of the keys corresponding to individual group evaluations in the results
52- dictionary returned by Metric.compute_group_wise()."""
55+ """The name of the keys corresponding to individual group evaluations
56+ in the results dictionary returned by Metric.compute_group_wise()."""
5357 return f'{ self .name } _group:{ group_idx } '
5458
5559 @property
5660 def worst_group_metric_field (self ):
57- """The name of the keys corresponding to the worst-group metric in the results dictionary
58- returned by Metric.compute_group_wise()."""
61+ """The name of the keys corresponding to the worst-group metric in the
62+ results dictionary returned by Metric.compute_group_wise()."""
5963 return f'{ self .name } _wg'
6064
6165 def group_count_field (self , group_idx ):
62- """The name of the keys corresponding to each group's count in the results dictionary
63- returned by Metric.compute_group_wise()."""
66+ """The name of the keys corresponding to each group's count in the
67+ results dictionary returned by Metric.compute_group_wise()."""
6468 return f'count_group:{ group_idx } '
6569
6670 def compute (self , y_pred , y_true , return_dict = True ):
@@ -140,7 +144,8 @@ class ElementwiseMetric(Metric):
140144 """Averages."""
141145
142146 def _compute_element_wise (self , y_pred , y_true ):
143- """Helper for computing element-wise metric, implemented for each metric.
147+ """Helper for computing element-wise metric, implemented for each
148+ metric.
144149
145150 Args:
146151 - y_pred (Tensor): Predicted targets or model output
@@ -151,7 +156,8 @@ def _compute_element_wise(self, y_pred, y_true):
151156 raise NotImplementedError
152157
153158 def worst (self , metrics ):
154- """Given a list/numpy array/Tensor of metrics, computes the worst-case metric.
159+ """Given a list/numpy array/Tensor of metrics, computes the worst-case
160+ metric.
155161
156162 Args:
157163 - metrics (Tensor, numpy array, or list): Metrics
@@ -182,7 +188,8 @@ def _compute_group_wise(self, y_pred, y_true, g, n_groups):
182188
183189 @property
184190 def agg_metric_field (self ):
185- """The name of the key in the results dictionary returned by Metric.compute()."""
191+ """The name of the key in the results dictionary returned by
192+ Metric.compute()."""
186193 return f'{ self .name } _avg'
187194
188195 def compute_element_wise (self , y_pred , y_true , return_dict = True ):
0 commit comments