Skip to content

Commit 428cddc

Browse files
Fix the wrong stat & accuracy calculation (#539)
1 parent 06fdd31 commit 428cddc

File tree

1 file changed

+3
-2
lines changed

1 file changed

+3
-2
lines changed

segmentation_models_pytorch/metrics/functional.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -172,7 +172,8 @@ def _get_stats_multiclass(
172172
for i in range(batch_size):
173173
target_i = target[i]
174174
output_i = output[i]
175-
matched = target_i * (output_i == target_i)
175+
mask = output_i == target_i
176+
matched = torch.where(mask, target_i, -1)
176177
tp = torch.histc(matched.float(), bins=num_classes, min=0, max=num_classes - 1)
177178
fp = torch.histc(output_i.float(), bins=num_classes, min=0, max=num_classes - 1) - tp
178179
fn = torch.histc(target_i.float(), bins=num_classes, min=0, max=num_classes - 1) - tp
@@ -295,7 +296,7 @@ def _iou_score(tp, fp, fn, tn):
295296

296297

297298
def _accuracy(tp, fp, fn, tn):
298-
return tp / (tp + fp + fn + tn)
299+
return (tp + tn) / (tp + fp + fn + tn)
299300

300301

301302
def _sensitivity(tp, fp, fn, tn):

0 commit comments

Comments
 (0)