Skip to content

Commit a619f51

Browse files
committed
Internal change
PiperOrigin-RevId: 509322980
1 parent a380ae7 commit a619f51

File tree

3 files changed

+141
-6
lines changed

3 files changed

+141
-6
lines changed
Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,74 @@
1+
# --experiment_type=retinanet_mobile_coco_qat
2+
# --topology=4x4
3+
# --tpu_platform=df
4+
# COCO mAP: 24.43 from QAT training and 23.1 from the TFLite after conversion.
5+
# QAT only supports float32 tpu due to fake-quant op.
6+
runtime:
7+
distribution_strategy: 'tpu'
8+
mixed_precision_dtype: 'float32'
9+
task:
10+
losses:
11+
l2_weight_decay: 0.0
12+
model:
13+
anchor:
14+
anchor_size: 3
15+
aspect_ratios: [0.5, 1.0, 2.0]
16+
num_scales: 3
17+
backbone:
18+
mobilenet:
19+
model_id: 'MobileNetMultiAVG'
20+
filter_size_scale: 1.0
21+
type: 'mobilenet'
22+
decoder:
23+
type: 'fpn'
24+
fpn:
25+
num_filters: 128
26+
use_separable_conv: true
27+
use_keras_layer: true
28+
head:
29+
num_convs: 4
30+
num_filters: 128
31+
use_separable_conv: true
32+
input_size: [256, 256, 3]
33+
max_level: 7
34+
min_level: 3
35+
norm_activation:
36+
activation: 'relu6'
37+
norm_epsilon: 0.001
38+
norm_momentum: 0.99
39+
use_sync_bn: true
40+
train_data:
41+
dtype: 'float32'
42+
global_batch_size: 256
43+
is_training: true
44+
parser:
45+
aug_rand_hflip: true
46+
aug_scale_max: 2.0
47+
aug_scale_min: 0.5
48+
validation_data:
49+
dtype: 'float32'
50+
global_batch_size: 256
51+
is_training: false
52+
drop_remainder: false
53+
quantization:
54+
pretrained_original_checkpoint: 'gs://**/coco_mobilenetv3.5_avg_mobile_tpu/ckpt-277200'
55+
quantize_detection_decoder: true
56+
quantize_detection_head: true
57+
trainer:
58+
best_checkpoint_eval_metric: AP
59+
best_checkpoint_export_subdir: best_ckpt
60+
best_checkpoint_metric_comp: higher
61+
optimizer_config:
62+
learning_rate:
63+
type: 'exponential'
64+
exponential:
65+
decay_rate: 0.96
66+
decay_steps: 231
67+
initial_learning_rate: 0.5
68+
name: 'ExponentialDecay'
69+
offset: 0
70+
staircase: true
71+
steps_per_loop: 462
72+
train_steps: 46200
73+
validation_interval: 462
74+
validation_steps: 20
Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,65 @@
1+
# --experiment_type=retinanet_mobile_coco
2+
# COCO AP 24.92%
3+
# Use 4x4 DF for training.
4+
runtime:
5+
distribution_strategy: 'tpu'
6+
mixed_precision_dtype: 'bfloat16'
7+
task:
8+
losses:
9+
l2_weight_decay: 3.0e-05
10+
model:
11+
anchor:
12+
anchor_size: 3
13+
aspect_ratios: [0.5, 1.0, 2.0]
14+
num_scales: 3
15+
backbone:
16+
mobilenet:
17+
model_id: 'MobileNetMultiAVG'
18+
filter_size_scale: 1.0
19+
type: 'mobilenet'
20+
decoder:
21+
type: 'fpn'
22+
fpn:
23+
num_filters: 128
24+
use_separable_conv: true
25+
use_keras_layer: true
26+
head:
27+
num_convs: 4
28+
num_filters: 128
29+
use_separable_conv: true
30+
input_size: [256, 256, 3]
31+
max_level: 7
32+
min_level: 3
33+
norm_activation:
34+
activation: 'relu6'
35+
norm_epsilon: 0.001
36+
norm_momentum: 0.99
37+
use_sync_bn: true
38+
train_data:
39+
dtype: 'bfloat16'
40+
global_batch_size: 256
41+
is_training: true
42+
parser:
43+
aug_rand_hflip: true
44+
aug_scale_max: 2.0
45+
aug_scale_min: 0.5
46+
validation_data:
47+
dtype: 'bfloat16'
48+
global_batch_size: 256
49+
is_training: false
50+
drop_remainder: false
51+
trainer:
52+
optimizer_config:
53+
learning_rate:
54+
stepwise:
55+
boundaries: [263340, 272580]
56+
values: [0.32, 0.032, 0.0032]
57+
type: 'stepwise'
58+
warmup:
59+
linear:
60+
warmup_learning_rate: 0.0067
61+
warmup_steps: 2000
62+
steps_per_loop: 462
63+
train_steps: 277200
64+
validation_interval: 462
65+
validation_steps: 20

official/vision/modeling/layers/detection_generator.py

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -851,14 +851,10 @@ def _generate_detections_tflite(
851851
config.update({'num_classes': num_classes})
852852

853853
for i in range(min_level, max_level + 1):
854-
scores.append(
855-
tf.sigmoid(
856-
tf.reshape(raw_scores[str(i)], [batch_size, -1, num_classes])
857-
)
858-
)
854+
scores.append(tf.reshape(raw_scores[str(i)], [batch_size, -1, num_classes]))
859855
boxes.append(tf.reshape(raw_boxes[str(i)], [batch_size, -1, 4]))
860856
anchors.append(tf.reshape(anchor_boxes[str(i)], [-1, 4]))
861-
scores = tf.concat(scores, 1)
857+
scores = tf.sigmoid(tf.concat(scores, 1))
862858
boxes = tf.concat(boxes, 1)
863859
anchors = tf.concat(anchors, 0)
864860

0 commit comments

Comments
 (0)