File tree Expand file tree Collapse file tree 1 file changed +9
-4
lines changed
segmentation_models_pytorch/losses Expand file tree Collapse file tree 1 file changed +9
-4
lines changed Original file line number Diff line number Diff line change @@ -191,10 +191,15 @@ def soft_tversky_score(
191
191
192
192
"""
193
193
assert output .size () == target .size ()
194
-
195
- output_sum = torch .sum (output , dim = dims )
196
- target_sum = torch .sum (target , dim = dims )
197
- difference = LA .vector_norm (output - target , ord = 1 , dim = dims )
194
+
195
+ if dims is not None :
196
+ output_sum = torch .sum (output , dim = dims )
197
+ target_sum = torch .sum (target , dim = dims )
198
+ difference = LA .vector_norm (output - target , ord = 1 , dim = dims )
199
+ else :
200
+ output_sum = torch .sum (output )
201
+ target_sum = torch .sum (target )
202
+ difference = LA .vector_norm (output - target , ord = 1 )
198
203
199
204
intersection = (output_sum + target_sum - difference ) / 2 # TP
200
205
fp = output_sum - intersection
You can’t perform that action at this time.
0 commit comments