Skip to content

Commit 545f74d

Browse files
author
humanpose1
committed
change based model
1 parent 043d8d7 commit 545f74d

File tree

6 files changed

+94
-96
lines changed

6 files changed

+94
-96
lines changed

torch_points3d/metrics/__init__.py

Whitespace-only changes.

torch_points3d/metrics/segmentation/__init__.py

Whitespace-only changes.
Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,81 @@
1+
import torch
2+
from typing import Dict, Optional, Tuple, Any, Union
3+
4+
5+
def compute_average_intersection_union(confusion_matrix: torch.Tensor, missing_as_one: bool = False) -> torch.Tensor:
6+
"""
7+
compute intersection over union on average from confusion matrix
8+
Parameters
9+
Parameters
10+
----------
11+
confusion_matrix: torch.Tensor
12+
square matrix
13+
missing_as_one: bool, default: False
14+
"""
15+
16+
values, existing_classes_mask = compute_intersection_union_per_class(confusion_matrix, return_existing_mask=True)
17+
if torch.sum(existing_classes_mask) == 0:
18+
return torch.sum(existing_classes_mask)
19+
if missing_as_one:
20+
values[~existing_classes_mask] = 1
21+
existing_classes_mask[:] = True
22+
return torch.sum(values[existing_classes_mask]) / torch.sum(existing_classes_mask)
23+
24+
25+
def compute_mean_class_accuracy(confusion_matrix: torch.Tensor) -> torch.Tensor:
26+
"""
27+
compute intersection over union on average from confusion matrix
28+
29+
Parameters
30+
----------
31+
confusion_matrix: torch.Tensor
32+
square matrix
33+
"""
34+
total_gts = confusion_matrix.sum(1)
35+
labels_presents = torch.where(total_gts > 0)[0]
36+
if len(labels_presents) == 0:
37+
return total_gts[0]
38+
ones = torch.ones_like(total_gts)
39+
max_ones_total_gts = torch.cat([total_gts[None, :], ones[None, :]], 0).max(0)[0]
40+
re = (torch.diagonal(confusion_matrix)[labels_presents].float() / max_ones_total_gts[labels_presents]).sum()
41+
return re / float(len(labels_presents))
42+
43+
44+
def compute_overall_accuracy(confusion_matrix: torch.Tensor) -> Union[int, torch.Tensor]:
45+
"""
46+
compute overall accuracy from confusion matrix
47+
48+
Parameters
49+
----------
50+
confusion_matrix: torch.Tensor
51+
square matrix
52+
"""
53+
all_values = confusion_matrix.sum()
54+
if all_values == 0:
55+
return 0
56+
matrix_diagonal = torch.trace(confusion_matrix)
57+
return matrix_diagonal.float() / all_values
58+
59+
60+
def compute_intersection_union_per_class(
61+
confusion_matrix: torch.Tensor, return_existing_mask: bool = False, eps: float = 1e-8
62+
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
63+
"""
64+
compute intersection over union per class from confusion matrix
65+
66+
Parameters
67+
----------
68+
confusion_matrix: torch.Tensor
69+
square matrix
70+
"""
71+
72+
TP_plus_FN = confusion_matrix.sum(0)
73+
TP_plus_FP = confusion_matrix.sum(1)
74+
TP = torch.diagonal(confusion_matrix)
75+
union = TP_plus_FN + TP_plus_FP - TP
76+
iou = eps + TP / (union + eps)
77+
existing_class_mask = union > 1e-3
78+
if return_existing_mask:
79+
return iou, existing_class_mask
80+
else:
81+
return iou, None

torch_points3d/metrics/segmentation/segmentation_tracker.py

Lines changed: 5 additions & 83 deletions
Original file line numberDiff line numberDiff line change
@@ -4,85 +4,7 @@
44
from torchmetrics import Metric
55

66
from torch_points3d.metrics.base_tracker import BaseTracker
7-
8-
9-
def compute_average_intersection_union(confusion_matrix: torch.Tensor, missing_as_one: bool = False) -> torch.Tensor:
10-
"""
11-
compute intersection over union on average from confusion matrix
12-
Parameters
13-
Parameters
14-
----------
15-
confusion_matrix: torch.Tensor
16-
square matrix
17-
missing_as_one: bool, default: False
18-
"""
19-
20-
values, existing_classes_mask = compute_intersection_union_per_class(confusion_matrix, return_existing_mask=True)
21-
if torch.sum(existing_classes_mask) == 0:
22-
return torch.sum(existing_classes_mask)
23-
if missing_as_one:
24-
values[~existing_classes_mask] = 1
25-
existing_classes_mask[:] = True
26-
return torch.sum(values[existing_classes_mask]) / torch.sum(existing_classes_mask)
27-
28-
29-
def compute_mean_class_accuracy(confusion_matrix: torch.Tensor) -> torch.Tensor:
30-
"""
31-
compute intersection over union on average from confusion matrix
32-
33-
Parameters
34-
----------
35-
confusion_matrix: torch.Tensor
36-
square matrix
37-
"""
38-
total_gts = confusion_matrix.sum(1)
39-
labels_presents = torch.where(total_gts > 0)[0]
40-
if len(labels_presents) == 0:
41-
return total_gts[0]
42-
ones = torch.ones_like(total_gts)
43-
max_ones_total_gts = torch.cat([total_gts[None, :], ones[None, :]], 0).max(0)[0]
44-
re = (torch.diagonal(confusion_matrix)[labels_presents].float() / max_ones_total_gts[labels_presents]).sum()
45-
return re / float(len(labels_presents))
46-
47-
48-
def compute_overall_accuracy(confusion_matrix: torch.Tensor) -> Union[int, torch.Tensor]:
49-
"""
50-
compute overall accuracy from confusion matrix
51-
52-
Parameters
53-
----------
54-
confusion_matrix: torch.Tensor
55-
square matrix
56-
"""
57-
all_values = confusion_matrix.sum()
58-
if all_values == 0:
59-
return 0
60-
matrix_diagonal = torch.trace(confusion_matrix)
61-
return matrix_diagonal.float() / all_values
62-
63-
64-
def compute_intersection_union_per_class(
65-
confusion_matrix: torch.Tensor, return_existing_mask: bool = False, eps: float = 1e-8
66-
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
67-
"""
68-
compute intersection over union per class from confusion matrix
69-
70-
Parameters
71-
----------
72-
confusion_matrix: torch.Tensor
73-
square matrix
74-
"""
75-
76-
TP_plus_FN = confusion_matrix.sum(0)
77-
TP_plus_FP = confusion_matrix.sum(1)
78-
TP = torch.diagonal(confusion_matrix)
79-
union = TP_plus_FN + TP_plus_FP - TP
80-
iou = eps + TP / (union + eps)
81-
existing_class_mask = union > 1e-3
82-
if return_existing_mask:
83-
return iou, existing_class_mask
84-
else:
85-
return iou, None
7+
import torch_points3d.metrics.segmentation.metrics as mt
868

879

8810
class SegmentationTracker(BaseTracker):
@@ -104,10 +26,10 @@ def __init__(
10426
self.eps = eps
10527

10628
def compute_metrics_from_cm(self, matrix: torch.Tensor) -> Dict[str, Any]:
107-
acc = compute_overall_accuracy(matrix)
108-
macc = compute_mean_class_accuracy(matrix)
109-
miou = compute_average_intersection_union(matrix)
110-
iou_per_class, _ = compute_intersection_union_per_class(matrix, eps=self.eps)
29+
acc = mt.compute_overall_accuracy(matrix)
30+
macc = mt.compute_mean_class_accuracy(matrix)
31+
miou = mt.compute_average_intersection_union(matrix)
32+
iou_per_class, _ = mt.compute_intersection_union_per_class(matrix, eps=self.eps)
11133
iou_per_class_dict = {f"{self.stage}_iou_class_{i}": (100 * v) for i, v in enumerate(iou_per_class)}
11234
res = {
11335
"{}_acc".format(self.stage): 100 * acc,

torch_points3d/models/base_model.py

Lines changed: 3 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -23,14 +23,9 @@ def forward(self) -> Optional[torch.Tensor]:
2323
def compute_loss(self):
2424
raise (NotImplementedError("get_losses needs to be defined!"))
2525

26-
def get_losses(self, stage: Optional[str] = None) -> Optional[Dict["str", torch.Tensor]]:
27-
if stage is None:
28-
return self._losses
29-
else:
30-
losses = {}
31-
for name, l in self._losses.items():
32-
losses[f"{stage}_{name}"] = l.item()
33-
return losses
26+
def get_losses(self) -> Optional[Dict["str", torch.Tensor]]:
27+
return self._losses
28+
3429

3530
def get_outputs(self) -> Dict[str, Optional[torch.Tensor]]:
3631
"""

torch_points3d/models/segmentation/base_model.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -32,18 +32,18 @@ def forward(self) -> Optional[torch.Tensor]:
3232
features = self.backbone(self.input).x
3333
logits = self.head(features)
3434
self._output = F.log_softmax(logits, dim=-1)
35-
self.compute_losses()
36-
if "loss" in self._losses.keys():
37-
return self._losses["loss"]
38-
else:
39-
return None
35+
loss = self.compute_losses()
36+
return loss
4037

4138
def compute_losses(self):
4239
"""
4340
compute every loss. store the total loss in an attribute _loss
4441
"""
4542
if self.labels is not None and self.criterion is not None:
4643
self._losses["loss"] = self.criterion(self._output, self.labels)
44+
return self._losses["loss"]
45+
else:
46+
return None
4747

4848
def get_outputs(self) -> Dict[str, torch.Tensor]:
4949
"""

0 commit comments

Comments
 (0)