|
4 | 4 | from typing import Optional
|
5 | 5 |
|
6 | 6 | import torch
|
| 7 | +import torch.linalg as LA |
7 | 8 | import torch.nn.functional as F
|
8 | 9 |
|
9 | 10 | __all__ = [
|
@@ -190,20 +191,14 @@ def soft_tversky_score(
|
190 | 191 |
|
191 | 192 | """
|
192 | 193 | assert output.size() == target.size()
|
193 |
| - if dims is not None: |
194 |
| - difference = torch.norm(output - target, p=1, dim=dims) |
195 |
| - output_sum = torch.sum(output, dim=dims) |
196 |
| - target_sum = torch.sum(target, dim=dims) |
197 |
| - intersection = (output_sum + target_sum - difference) / 2 # TP |
198 |
| - fp = output_sum - intersection |
199 |
| - fn = target_sum - intersection |
200 |
| - else: |
201 |
| - difference = torch.norm(output - target, p=1) |
202 |
| - output_sum = torch.sum(output) |
203 |
| - target_sum = torch.sum(target) |
204 |
| - intersection = (output_sum + target_sum - difference) / 2 # TP |
205 |
| - fp = output_sum - intersection |
206 |
| - fn = target_sum - intersection |
| 194 | + |
| 195 | + output_sum = torch.sum(output, dim=dims) |
| 196 | + target_sum = torch.sum(target, dim=dims) |
| 197 | + difference = LA.vector_norm(output - target, ord=1, dim=dims) |
| 198 | + |
| 199 | + intersection = (output_sum + target_sum - difference) / 2 # TP |
| 200 | + fp = output_sum - intersection |
| 201 | + fn = target_sum - intersection |
207 | 202 |
|
208 | 203 | tversky_score = (intersection + smooth) / (
|
209 | 204 | intersection + alpha * fp + beta * fn + smooth
|
|
0 commit comments