
Commit d817d87

No public description
PiperOrigin-RevId: 673702497
1 parent 549c50d commit d817d87

File tree: 1 file changed, +132 -0 lines changed


official/projects/mosaic/configs/mosaic_config.py

Lines changed: 132 additions & 0 deletions
@@ -233,3 +233,135 @@ def mosaic_mnv35_cityscapes() -> cfg.ExperimentConfig:
      ])

  return config


@exp_factory.register_config_factory('mosaic_mnv4_cityscapes')
def mosaic_mnv4_cityscapes() -> cfg.ExperimentConfig:
  """Instantiates an experiment configuration for an image segmentation task.

  This image segmentation experiment is conducted on the Cityscapes dataset.
  The model architecture is a MOSAIC encoder-decoder. The default backbone
  network is an experimental MobileNet V4 variant on top of which the MOSAIC
  encoder-decoder can be deployed. All detailed configurations can be
  overridden by a .yaml file provided by the user to launch the experiments.
  Please refer to the .yaml examples under ../configs/experiments/.

  Returns:
    A particular instance of cfg.ExperimentConfig for the MOSAIC-model-based
    image semantic segmentation task.
  """
  train_batch_size = 16
  eval_batch_size = 16
  steps_per_epoch = CITYSCAPES_TRAIN_EXAMPLES // train_batch_size
  output_stride = 16

  backbone_output_level = int(math.log2(output_stride))
  config = cfg.ExperimentConfig(
      task=MosaicSemanticSegmentationTask(
          model=MosaicSemanticSegmentationModel(
              # Cityscapes uses only 19 semantic classes for train/evaluation.
              # The void (background) class is ignored in train and evaluation.
              num_classes=19,
              input_size=[None, None, 3],
              backbone=backbones.Backbone(
                  type='mobilenet',
                  mobilenet=backbones.MobileNet(
                      model_id='MobileNetV4ConvMediumSeg',
                      output_intermediate_endpoints=True,
                      output_stride=output_stride)),
              neck=MosaicEncoderNeck(
                  encoder_input_level=backbone_output_level,
                  branch_filter_depths=[64, 64],
                  conv_kernel_sizes=[3, 5],
                  pyramid_pool_bin_nums=[1, 4, 8, 16],  # paper default
                  activation='relu',
                  dropout_rate=0.1,
                  kernel_initializer='glorot_uniform',
                  interpolation='bilinear',
                  use_depthwise_convolution=True),
              head=MosaicDecoderHead(
                  num_classes=19,
                  decoder_input_levels=['3/depthwise', '2/depthwise'],
                  decoder_stage_merge_styles=['concat_merge', 'sum_merge'],
                  decoder_filters=[64, 64],
                  decoder_projected_filters=[19, 19],
                  encoder_end_level=backbone_output_level,
                  use_additional_classifier_layer=False,
                  classifier_kernel_size=1,
                  activation='relu',
                  kernel_initializer='glorot_uniform',
                  interpolation='bilinear',
              ),
              norm_activation=common.NormActivation(
                  activation='relu',
                  norm_momentum=0.99,
                  norm_epsilon=1e-3,
                  use_sync_bn=True,
              ),
          ),
          losses=seg_cfg.Losses(l2_weight_decay=4e-5),
          train_data=seg_cfg.DataConfig(
              input_path=os.path.join(
                  CITYSCAPES_INPUT_PATH_BASE, 'train_fine**'
              ),
              crop_size=[1024, 2048],
              output_size=[1024, 2048],
              is_training=True,
              global_batch_size=train_batch_size,
              aug_scale_min=0.5,
              aug_scale_max=2.0,
          ),
          validation_data=seg_cfg.DataConfig(
              input_path=os.path.join(CITYSCAPES_INPUT_PATH_BASE, 'val_fine*'),
              output_size=[1024, 2048],
              is_training=False,
              global_batch_size=eval_batch_size,
              resize_eval_groundtruth=True,
              drop_remainder=False,
          ),
          # Imagenet pre-trained MobileNetV4ConvMediumSeg checkpoint.
          init_checkpoint=(
              'gs://tf_model_garden/vision/mobilenet/v4_seg_float//'
          ),
          init_checkpoint_modules='backbone',
      ),
      trainer=cfg.TrainerConfig(
          steps_per_loop=steps_per_epoch,
          summary_interval=steps_per_epoch,
          checkpoint_interval=steps_per_epoch,
          train_steps=100000,
          validation_steps=CITYSCAPES_VAL_EXAMPLES // eval_batch_size,
          validation_interval=steps_per_epoch,
          best_checkpoint_eval_metric='mean_iou',
          best_checkpoint_export_subdir='best_ckpt',
          best_checkpoint_metric_comp='higher',
          optimizer_config=optimization.OptimizationConfig({
              'optimizer': {
                  'type': 'sgd',
                  'sgd': {
                      'momentum': 0.9
                  }
              },
              'learning_rate': {
                  'type': 'polynomial',
                  'polynomial': {
                      'initial_learning_rate': 0.1,
                      'decay_steps': 100000,
                      'end_learning_rate': 0.0,
                      'power': 0.9
                  }
              },
              'warmup': {
                  'type': 'linear',
                  'linear': {
                      'warmup_steps': 5 * steps_per_epoch,
                      'warmup_learning_rate': 0
                  }
              }
          })),
      restrictions=[
          'task.train_data.is_training != None',
          'task.validation_data.is_training != None'
      ])

  return config
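For orientation, a quick back-of-the-envelope check of the schedule these values imply. This is illustrative arithmetic only; it assumes the module-level constants CITYSCAPES_TRAIN_EXAMPLES = 2975 and CITYSCAPES_VAL_EXAMPLES = 500 (the Cityscapes fine-annotation splits) defined earlier in mosaic_config.py.

# Schedule sanity check (assumes 2975 fine-annotated train / 500 val images).
train_batch_size = 16
eval_batch_size = 16
steps_per_epoch = 2975 // train_batch_size    # 185 steps per epoch
total_epochs = 100000 / steps_per_epoch       # ~540 passes for train_steps=100000
warmup_steps = 5 * steps_per_epoch            # 925 linear warmup steps
validation_steps = 500 // eval_batch_size     # 31 evaluation steps per validation run
print(steps_per_epoch, total_epochs, warmup_steps, validation_steps)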

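Below is a minimal sketch of how a registered experiment like this one is typically retrieved and adjusted programmatically in TF Model Garden. The exp_factory.get_exp_config lookup and the config.override / config.validate calls are the standard hyperparams entry points; the override values shown are illustrative placeholders only, and in practice the same fields are usually overridden from a user-provided .yaml file, as the docstring notes.

from official.core import exp_factory

# Retrieve the experiment registered above by its factory name.
config = exp_factory.get_exp_config('mosaic_mnv4_cityscapes')

# Illustrative overrides; any of these fields can equally be set from a
# .yaml config file passed at launch time.
config.override({
    'task': {'train_data': {'global_batch_size': 32}},
    'trainer': {'train_steps': 50000},
})

# Check the config against its declared restrictions.
config.validate()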