Skip to content

Commit 0f052d5

Browse files
author
humanpose1
committed
test + metrics
1 parent 545f74d commit 0f052d5

File tree

3 files changed

+8
-8
lines changed

3 files changed

+8
-8
lines changed

test/test_confusion_matrix.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -11,10 +11,10 @@
1111
sys.path.insert(0, ROOT)
1212
sys.path.append('.')
1313

14-
from torch_points3d.metrics.segmentation.segmentation_tracker import compute_intersection_union_per_class
15-
from torch_points3d.metrics.segmentation.segmentation_tracker import compute_average_intersection_union
16-
from torch_points3d.metrics.segmentation.segmentation_tracker import compute_overall_accuracy
17-
from torch_points3d.metrics.segmentation.segmentation_tracker import compute_mean_class_accuracy
14+
from torch_points3d.metrics.segmentation.metrics import compute_intersection_union_per_class
15+
from torch_points3d.metrics.segmentation.metrics import compute_average_intersection_union
16+
from torch_points3d.metrics.segmentation.metrics import compute_overall_accuracy
17+
from torch_points3d.metrics.segmentation.metrics import compute_mean_class_accuracy
1818

1919

2020

torch_points3d/metrics/base_tracker.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -21,9 +21,9 @@ def track(self, output_model, *args, **kwargs) -> Dict[str, Any]:
2121
def track_loss(self, losses: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
2222
out_loss = dict()
2323
for key, loss in losses.items():
24-
loss_key = "%s_%s" % (self.stage, key)
25-
if loss_key not in self.loss_metrics:
26-
self.loss_metrics[loss_key] = AverageMeter()
24+
loss_key = f"{self.stage}_{key}"
25+
if loss_key not in self.loss_metrics.keys():
26+
self.loss_metrics[loss_key] = AverageMeter().to(loss)
2727
val = self.loss_metrics[loss_key](loss)
2828
out_loss[loss_key] = val
2929
return out_loss

torch_points3d/tasks/base_model.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -98,7 +98,7 @@ def configure_metrics(self, stage: str) -> Optional[Any]:
9898
def _step(self, batch, batch_idx, stage: str):
9999
self.model.set_input(batch)
100100
loss = self.model.forward()
101-
losses = self.model.get_losses(stage=stage)
101+
losses = self.model.get_losses()
102102
outputs = self.model.get_outputs()
103103
metric_dict = self.tracker(outputs, losses)
104104
self.log_dict(metric_dict, prog_bar=True, on_step=False, on_epoch=True)

0 commit comments

Comments
 (0)