@@ -157,15 +157,7 @@ def soft_jaccard_score(
157
157
dims = None ,
158
158
) -> torch .Tensor :
159
159
assert output .size () == target .size ()
160
- if dims is not None :
161
- intersection = torch .sum (output * target , dim = dims )
162
- cardinality = torch .sum (output + target , dim = dims )
163
- else :
164
- intersection = torch .sum (output * target )
165
- cardinality = torch .sum (output + target )
166
-
167
- union = cardinality - intersection
168
- jaccard_score = (intersection + smooth ) / (union + smooth ).clamp_min (eps )
160
+ jaccard_score = soft_tversky_score (output , target , 1.0 , 1.0 , smooth , eps , dims )
169
161
return jaccard_score
170
162
171
163
@@ -177,13 +169,7 @@ def soft_dice_score(
177
169
dims = None ,
178
170
) -> torch .Tensor :
179
171
assert output .size () == target .size ()
180
- if dims is not None :
181
- intersection = torch .sum (output * target , dim = dims )
182
- cardinality = torch .sum (output + target , dim = dims )
183
- else :
184
- intersection = torch .sum (output * target )
185
- cardinality = torch .sum (output + target )
186
- dice_score = (2.0 * intersection + smooth ) / (cardinality + smooth ).clamp_min (eps )
172
+ dice_score = soft_tversky_score (output , target , 0.5 , 0.5 , smooth , eps , dims )
187
173
return dice_score
188
174
189
175
@@ -196,15 +182,28 @@ def soft_tversky_score(
196
182
eps : float = 1e-7 ,
197
183
dims = None ,
198
184
) -> torch .Tensor :
185
+ """Tversky loss
186
+
187
+ References:
188
+ https://arxiv.org/pdf/2302.05666
189
+ https://arxiv.org/pdf/2303.16296
190
+
191
+ """
199
192
assert output .size () == target .size ()
200
193
if dims is not None :
201
- intersection = torch .sum (output * target , dim = dims ) # TP
202
- fp = torch .sum (output * (1.0 - target ), dim = dims )
203
- fn = torch .sum ((1 - output ) * target , dim = dims )
194
+ difference = torch .norm (output - target , p = 1 , dim = dims )
195
+ output_sum = torch .sum (output , dim = dims )
196
+ target_sum = torch .sum (target , dim = dims )
197
+ intersection = (output_sum + target_sum - difference ) / 2 # TP
198
+ fp = output_sum - intersection
199
+ fn = target_sum - intersection
204
200
else :
205
- intersection = torch .sum (output * target ) # TP
206
- fp = torch .sum (output * (1.0 - target ))
207
- fn = torch .sum ((1 - output ) * target )
201
+ difference = torch .norm (output - target , p = 1 )
202
+ output_sum = torch .sum (output )
203
+ target_sum = torch .sum (target )
204
+ intersection = (output_sum + target_sum - difference ) / 2 # TP
205
+ fp = output_sum - intersection
206
+ fn = target_sum - intersection
208
207
209
208
tversky_score = (intersection + smooth ) / (
210
209
intersection + alpha * fp + beta * fn + smooth
0 commit comments