-
Notifications
You must be signed in to change notification settings - Fork 7
Expand file tree
/
Copy pathlosses.py
More file actions
27 lines (23 loc) · 1.61 KB
/
losses.py
File metadata and controls
27 lines (23 loc) · 1.61 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
import tensorflow as tf
def dice_loss(overly_small_text_region_training_mask, text_region_boundary_training_mask, loss_weight, small_text_weight, score_y_true, score_y_pred):
    """Return the weighted, masked Dice loss between two score maps.

    The per-element weight is
    ``min(overly_small_text_region_training_mask + small_text_weight, 1)
    * text_region_boundary_training_mask``, and the result is
    ``loss_weight * (1 - 2 * I / U)`` where ``I`` is the weighted sum of
    ``score_y_true * score_y_pred`` and ``U`` is the weighted sum of
    ``score_y_true`` plus the weighted sum of ``score_y_pred`` plus a small
    constant.

    Args:
        overly_small_text_region_training_mask: Mask tensor; presumably 0
            where a text region is too small to train on -- confirm with
            the caller.
        text_region_boundary_training_mask: Mask tensor multiplied into the
            per-element weight.
        loss_weight: Factor applied to the final loss value.
        small_text_weight: Value added to the small-text mask before it is
            capped at 1.
        score_y_true: Ground-truth score map.
        score_y_pred: Predicted score map.

    Returns:
        The scalar loss tensor (all sums reduce over every axis).
    """
    # Keeps the denominator non-zero when both masked maps sum to zero.
    smoothing = 1e-5

    capped_small_text = tf.minimum(
        overly_small_text_region_training_mask + small_text_weight, 1
    )
    pixel_weights = capped_small_text * text_region_boundary_training_mask

    overlap = tf.reduce_sum(score_y_true * score_y_pred * pixel_weights)
    masked_true_sum = tf.reduce_sum(score_y_true * pixel_weights)
    masked_pred_sum = tf.reduce_sum(score_y_pred * pixel_weights)
    denominator = masked_true_sum + masked_pred_sum + smoothing

    dice_coefficient = 2. * overlap / denominator
    return (1. - dice_coefficient) * loss_weight
def rbox_loss(overly_small_text_region_training_mask, text_region_boundary_training_mask, small_text_weight, target_score_map, geo_y_true, geo_y_pred):
    """Return the mean masked rotated-box geometry loss.

    Each geometry tensor is split into five equal slices along axis 3,
    read as the distances to the top, right, bottom and left box edges
    followed by the rotation angle. The per-element loss is
    ``-log((intersection + 1) / (union + 1)) + 20 * (1 - cos(angle diff))``,
    multiplied by ``target_score_map`` and by the training mask, and then
    averaged over all elements.

    Args:
        overly_small_text_region_training_mask: Mask tensor; presumably 0
            where a text region is too small to train on -- confirm with
            the caller.
        text_region_boundary_training_mask: Mask tensor multiplied into the
            per-element weight.
        small_text_weight: Value added to the small-text mask before it is
            capped at 1.
        target_score_map: Tensor multiplied into the per-element loss;
            presumably the ground-truth score map -- confirm with the caller.
        geo_y_true: Ground-truth geometry tensor (assumed to have 5 channels
            on axis 3 -- TODO confirm).
        geo_y_pred: Predicted geometry tensor, laid out like ``geo_y_true``.

    Returns:
        The scalar loss tensor.
    """
    # Per-element weight: small-text mask plus its weight, capped at 1,
    # gated by the boundary mask.
    capped_small_text = tf.minimum(
        overly_small_text_region_training_mask + small_text_weight, 1
    )
    pixel_weights = capped_small_text * text_region_boundary_training_mask

    # Channel order: top, right, bottom, left, angle.
    gt_top, gt_right, gt_bottom, gt_left, gt_angle = tf.split(
        value=geo_y_true, num_or_size_splits=5, axis=3
    )
    pred_top, pred_right, pred_bottom, pred_left, pred_angle = tf.split(
        value=geo_y_pred, num_or_size_splits=5, axis=3
    )

    gt_area = (gt_top + gt_bottom) * (gt_right + gt_left)
    pred_area = (pred_top + pred_bottom) * (pred_right + pred_left)

    # Extent of the overlap between the two boxes along each direction.
    overlap_width = tf.minimum(gt_right, pred_right) + tf.minimum(gt_left, pred_left)
    overlap_height = tf.minimum(gt_top, pred_top) + tf.minimum(gt_bottom, pred_bottom)
    intersection_area = overlap_width * overlap_height
    union_area = gt_area + pred_area - intersection_area

    # The +1.0 terms keep the ratio finite when both areas are zero.
    iou_term = -tf.math.log((intersection_area + 1.0) / (union_area + 1.0))
    angle_term = 1 - tf.cos(pred_angle - gt_angle)
    geometry_term = iou_term + 20 * angle_term

    return tf.reduce_mean(geometry_term * target_score_map * pixel_weights)