44from torchmetrics import Metric
55
66from torch_points3d .metrics .base_tracker import BaseTracker
7-
8-
9- def compute_average_intersection_union (confusion_matrix : torch .Tensor , missing_as_one : bool = False ) -> torch .Tensor :
10- """
11- compute intersection over union on average from confusion matrix
12- Parameters
13- Parameters
14- ----------
15- confusion_matrix: torch.Tensor
16- square matrix
17- missing_as_one: bool, default: False
18- """
19-
20- values , existing_classes_mask = compute_intersection_union_per_class (confusion_matrix , return_existing_mask = True )
21- if torch .sum (existing_classes_mask ) == 0 :
22- return torch .sum (existing_classes_mask )
23- if missing_as_one :
24- values [~ existing_classes_mask ] = 1
25- existing_classes_mask [:] = True
26- return torch .sum (values [existing_classes_mask ]) / torch .sum (existing_classes_mask )
27-
28-
29- def compute_mean_class_accuracy (confusion_matrix : torch .Tensor ) -> torch .Tensor :
30- """
31- compute intersection over union on average from confusion matrix
32-
33- Parameters
34- ----------
35- confusion_matrix: torch.Tensor
36- square matrix
37- """
38- total_gts = confusion_matrix .sum (1 )
39- labels_presents = torch .where (total_gts > 0 )[0 ]
40- if len (labels_presents ) == 0 :
41- return total_gts [0 ]
42- ones = torch .ones_like (total_gts )
43- max_ones_total_gts = torch .cat ([total_gts [None , :], ones [None , :]], 0 ).max (0 )[0 ]
44- re = (torch .diagonal (confusion_matrix )[labels_presents ].float () / max_ones_total_gts [labels_presents ]).sum ()
45- return re / float (len (labels_presents ))
46-
47-
48- def compute_overall_accuracy (confusion_matrix : torch .Tensor ) -> Union [int , torch .Tensor ]:
49- """
50- compute overall accuracy from confusion matrix
51-
52- Parameters
53- ----------
54- confusion_matrix: torch.Tensor
55- square matrix
56- """
57- all_values = confusion_matrix .sum ()
58- if all_values == 0 :
59- return 0
60- matrix_diagonal = torch .trace (confusion_matrix )
61- return matrix_diagonal .float () / all_values
62-
63-
64- def compute_intersection_union_per_class (
65- confusion_matrix : torch .Tensor , return_existing_mask : bool = False , eps : float = 1e-8
66- ) -> Tuple [torch .Tensor , Optional [torch .Tensor ]]:
67- """
68- compute intersection over union per class from confusion matrix
69-
70- Parameters
71- ----------
72- confusion_matrix: torch.Tensor
73- square matrix
74- """
75-
76- TP_plus_FN = confusion_matrix .sum (0 )
77- TP_plus_FP = confusion_matrix .sum (1 )
78- TP = torch .diagonal (confusion_matrix )
79- union = TP_plus_FN + TP_plus_FP - TP
80- iou = eps + TP / (union + eps )
81- existing_class_mask = union > 1e-3
82- if return_existing_mask :
83- return iou , existing_class_mask
84- else :
85- return iou , None
7+ import torch_points3d .metrics .segmentation .metrics as mt
868
879
8810class SegmentationTracker (BaseTracker ):
@@ -104,10 +26,10 @@ def __init__(
10426 self .eps = eps
10527
10628 def compute_metrics_from_cm (self , matrix : torch .Tensor ) -> Dict [str , Any ]:
107- acc = compute_overall_accuracy (matrix )
108- macc = compute_mean_class_accuracy (matrix )
109- miou = compute_average_intersection_union (matrix )
110- iou_per_class , _ = compute_intersection_union_per_class (matrix , eps = self .eps )
29+ acc = mt . compute_overall_accuracy (matrix )
30+ macc = mt . compute_mean_class_accuracy (matrix )
31+ miou = mt . compute_average_intersection_union (matrix )
32+ iou_per_class , _ = mt . compute_intersection_union_per_class (matrix , eps = self .eps )
11133 iou_per_class_dict = {f"{ self .stage } _iou_class_{ i } " : (100 * v ) for i , v in enumerate (iou_per_class )}
11234 res = {
11335 "{}_acc" .format (self .stage ): 100 * acc ,
0 commit comments