Skip to content

Commit 259d429

Browse files
No public description
PiperOrigin-RevId: 671496532
1 parent c8e5cd5 commit 259d429

File tree

3 files changed

+148
-0
lines changed

3 files changed

+148
-0
lines changed
Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,65 @@
1+
# Top-1 accuracy: 77.15% / 66.57% (val / train) @ 500 epochs
2+
runtime:
3+
distribution_strategy: 'tpu'
4+
mixed_precision_dtype: 'bfloat16'
5+
task:
6+
model:
7+
num_classes: 1001
8+
input_size: [256, 256, 3]
9+
backbone:
10+
mobilenet:
11+
model_id: 'MobileNetV4ConvMediumSeg'
12+
flat_stochastic_depth_drop_rate: false
13+
stochastic_depth_drop_rate: 0.075
14+
type: 'mobilenet'
15+
norm_activation:
16+
norm_epsilon: 0.001
17+
norm_momentum: 0.997
18+
dropout_rate: 0.2
19+
losses:
20+
l2_weight_decay: 0.0
21+
label_smoothing: 0.1
22+
train_data:
23+
input_path: 'gs://mlcompass-data/imagenet/imagenet-2012-tfrecord/train*'
24+
is_training: true
25+
global_batch_size: 4096
26+
dtype: 'bfloat16'
27+
aug_type:
28+
randaug:
29+
cutout_const: 20
30+
exclude_ops: ['Cutout']
31+
magnitude: 15
32+
prob_to_apply: 0.7
33+
type: 'randaug'
34+
validation_data:
35+
input_path: 'gs://mlcompass-data/imagenet/imagenet-2012-tfrecord/valid*'
36+
is_training: false
37+
global_batch_size: 4096
38+
dtype: 'bfloat16'
39+
drop_remainder: false
40+
trainer:
41+
train_steps: 156000
42+
validation_steps: 13
43+
validation_interval: 312
44+
steps_per_loop: 312
45+
summary_interval: 312
46+
checkpoint_interval: 312
47+
optimizer_config:
48+
learning_rate:
49+
cosine:
50+
decay_steps: 156000
51+
initial_learning_rate: 0.004
52+
name: 'CosineDecay'
53+
type: 'cosine'
54+
optimizer:
55+
adamw:
56+
exclude_from_weight_decay: ['batch_normalization']
57+
gradient_clip_norm: 0.0
58+
weight_decay_rate: 0.1
59+
name: 'AdamWeightDecay'
60+
type: 'adamw'
61+
warmup:
62+
linear:
63+
warmup_steps: 1560
64+
name: 'linear'
65+
type: 'linear'

official/vision/modeling/backbones/mobilenet.py

Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -766,6 +766,77 @@ def uib(
766766
}
767767

768768

769+
def _mnv4_conv_medium_seg_block_specs():
  """Builds the block specs for the MobileNetV4ConvMediumSeg backbone.

  This is a MobileNetV4ConvMedium layout tailored for dense prediction tasks
  such as segmentation.

  Returns:
    A dict with three entries:
      'spec_name': the model id, 'MobileNetV4ConvMediumSeg'.
      'block_spec_schema': the value returned by `block_spec_field_list()`.
      'block_specs': the ordered `BlockSpec` list converted with
        `block_spec_values_to_list`.
  """

  def conv_bn(kernel_size, strides, filters, is_output=False):
    """Returns the spec of a 'convbn' block."""
    return BlockSpec(
        block_fn='convbn',
        kernel_size=kernel_size,
        strides=strides,
        filters=filters,
        is_output=is_output,
    )

  def fused_inverted_bottleneck(
      kernel_size, strides, filters, is_output=False
  ):
    """Returns the spec of a 'fused_ib' block; expand ratio is fixed to 4."""
    return BlockSpec(
        block_fn='fused_ib',
        kernel_size=kernel_size,
        strides=strides,
        filters=filters,
        expand_ratio=4.0,
        is_output=is_output,
    )

  def universal_inverted_bottleneck(
      start_ks, middle_ks, strides, filters, expand_ratio, is_output=False
  ):
    """Returns the spec of a 'uib' block; layer scale is always disabled."""
    return BlockSpec(
        block_fn='uib',
        start_dw_kernel_size=start_ks,
        middle_dw_kernel_size=middle_ks,
        strides=strides,
        filters=filters,
        expand_ratio=expand_ratio,
        use_layer_scale=False,
        is_output=is_output,
    )

  # Short alias keeps the stage tables below compact.
  uib = universal_inverted_bottleneck

  # Blocks flagged with `is_output=True` close a stage.
  # NOTE(review): presumably these are the blocks exposed as backbone
  # endpoints -- confirm against the MobileNet builder.
  stem = [
      conv_bn(3, 2, 32),
      fused_inverted_bottleneck(3, 2, 48, is_output=True),
  ]

  stage_3 = [
      uib(3, 5, 2, 80, 4.0),
      uib(3, 3, 1, 80, 2.0, is_output=True),
  ]

  stage_4 = [
      uib(3, 5, 2, 160, 6.0),
      uib(3, 3, 1, 160, 4.0),
      uib(3, 3, 1, 160, 4.0),
      uib(3, 5, 1, 160, 4.0),
      uib(3, 3, 1, 160, 4.0),
      uib(3, 0, 1, 160, 4.0),
      uib(3, 0, 1, 160, 4.0, is_output=True),
  ]

  # The first block of this stage uses 256 filters; the rest use 128. None of
  # them is flagged as an output: the 448-filter conv in the head is instead.
  stage_5 = [
      uib(5, 5, 2, 256, 6.0),
      uib(5, 5, 1, 128, 4.0),
      uib(3, 5, 1, 128, 4.0),
      uib(3, 5, 1, 128, 4.0),
      uib(3, 0, 1, 128, 4.0),
      uib(3, 5, 1, 128, 2.0),
      uib(5, 5, 1, 128, 4.0),
      uib(5, 0, 1, 128, 2.0),
  ]

  # FC layers.
  head = [
      conv_bn(1, 1, 448, is_output=True),
      BlockSpec(block_fn='gpooling', is_output=False),
      conv_bn(1, 1, 1280),
  ]

  ordered_blocks = stem + stage_3 + stage_4 + stage_5 + head

  return {
      'spec_name': 'MobileNetV4ConvMediumSeg',
      'block_spec_schema': block_spec_field_list(),
      'block_specs': block_spec_values_to_list(ordered_blocks),
  }
838+
839+
769840
MNV4ConvLarge_BLOCK_SPECS = {
770841
'spec_name': 'MobileNetV4ConvLarge',
771842
'block_spec_schema': [
@@ -1077,6 +1148,7 @@ def mhsa_12px():
10771148
'MobileNetV4ConvLarge': MNV4ConvLarge_BLOCK_SPECS,
10781149
'MobileNetV4HybridMedium': _mnv4_hybrid_medium_block_specs(),
10791150
'MobileNetV4HybridLarge': _mnv4_hybrid_large_block_specs(),
1151+
'MobileNetV4ConvMediumSeg': _mnv4_conv_medium_seg_block_specs(),
10801152
}
10811153

10821154

official/vision/modeling/backbones/mobilenet_test.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@ class MobileNetTest(parameterized.TestCase, tf.test.TestCase):
4343
'MobileNetV4ConvLarge',
4444
'MobileNetV4HybridMedium',
4545
'MobileNetV4HybridLarge',
46+
'MobileNetV4ConvMediumSeg',
4647
)
4748
def test_serialize_deserialize(self, model_id):
4849
# Create a network object that sets all of its config options.
@@ -96,6 +97,7 @@ def test_serialize_deserialize(self, model_id):
9697
'MobileNetV4ConvLarge',
9798
'MobileNetV4HybridMedium',
9899
'MobileNetV4HybridLarge',
100+
'MobileNetV4ConvMediumSeg',
99101
],
100102
)
101103
)
@@ -126,6 +128,7 @@ def test_input_specs(self, input_dim, model_id):
126128
'MobileNetV4ConvLarge',
127129
'MobileNetV4HybridMedium',
128130
'MobileNetV4HybridLarge',
131+
'MobileNetV4ConvMediumSeg',
129132
],
130133
[32, 224],
131134
)
@@ -153,6 +156,7 @@ def test_mobilenet_creation(self, model_id,
153156
'MobileNetV4ConvLarge': [48, 96, 192, 512],
154157
'MobileNetV4HybridMedium': [48, 80, 160, 256],
155158
'MobileNetV4HybridLarge': [48, 96, 192, 512],
159+
'MobileNetV4ConvMediumSeg': [48, 80, 160, 448],
156160
}
157161

158162
network = mobilenet.MobileNet(model_id=model_id,
@@ -184,6 +188,7 @@ def test_mobilenet_creation(self, model_id,
184188
'MobileNetV4ConvLarge',
185189
'MobileNetV4HybridMedium',
186190
'MobileNetV4HybridLarge',
191+
'MobileNetV4ConvMediumSeg',
187192
],
188193
[32, 224],
189194
)
@@ -211,6 +216,7 @@ def test_mobilenet_intermediate_layers(self, model_id, input_size):
211216
'MobileNetV4ConvLarge': [None, None, None, None],
212217
'MobileNetV4HybridMedium': [None, None, None, None],
213218
'MobileNetV4HybridLarge': [None, None, None, None],
219+
'MobileNetV4ConvMediumSeg': [None, None, None, None],
214220
}
215221
network = mobilenet.MobileNet(model_id=model_id,
216222
filter_size_scale=1.0,
@@ -247,6 +253,7 @@ def test_mobilenet_intermediate_layers(self, model_id, input_size):
247253
'MobileNetV4ConvLarge',
248254
'MobileNetV4HybridMedium',
249255
'MobileNetV4HybridLarge',
256+
'MobileNetV4ConvMediumSeg',
250257
],
251258
[1.0, 0.75],
252259
)
@@ -285,6 +292,8 @@ def test_mobilenet_scaling(self, model_id,
285292
('MobileNetV4HybridMedium', 0.75): 6072584,
286293
('MobileNetV4HybridLarge', 1.0): 36648024,
287294
('MobileNetV4HybridLarge', 0.75): 21598064,
295+
('MobileNetV4ConvMediumSeg', 1.0): 3787024,
296+
('MobileNetV4ConvMediumSeg', 0.75): 2302536,
288297
}
289298

290299
input_size = 224
@@ -314,6 +323,7 @@ def test_mobilenet_scaling(self, model_id,
314323
'MobileNetV4ConvLarge',
315324
'MobileNetV4HybridMedium',
316325
'MobileNetV4HybridLarge',
326+
'MobileNetV4ConvMediumSeg',
317327
],
318328
[8, 16, 32],
319329
)
@@ -340,6 +350,7 @@ def test_mobilenet_output_stride(self, model_id, output_stride):
340350
'MobileNetV4ConvLarge': 512,
341351
'MobileNetV4HybridMedium': 256,
342352
'MobileNetV4HybridLarge': 512,
353+
'MobileNetV4ConvMediumSeg': 448,
343354
}
344355

345356
network = mobilenet.MobileNet(

0 commit comments

Comments
 (0)