Skip to content

Commit 98a558b

Browse files
yuanliangzhetensorflower-gardener
authored andcommitted
#movinet Add se_type option in tools/convert_3d_2plus1d.py
PiperOrigin-RevId: 420368239
1 parent d58be67 commit 98a558b

File tree

2 files changed

+5
-0
lines changed

2 files changed

+5
-0
lines changed

official/vision/beta/projects/movinet/tools/convert_3d_2plus1d.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,8 @@
2929
'Export path to save the saved_model file.')
3030
flags.DEFINE_string(
3131
'model_id', 'a0', 'MoViNet model name.')
32+
flags.DEFINE_string(
33+
'se_type', '2plus3d', 'MoViNet model SE type.')
3234
flags.DEFINE_bool(
3335
'causal', True, 'Run the model in causal mode.')
3436
flags.DEFINE_bool(
@@ -46,6 +48,7 @@ def main(_) -> None:
4648
backbone_2plus1d = movinet.Movinet(
4749
model_id=FLAGS.model_id,
4850
causal=FLAGS.causal,
51+
se_type=FLAGS.se_type,
4952
conv_type='2plus1d',
5053
use_positional_encoding=FLAGS.use_positional_encoding)
5154
model_2plus1d = movinet_model.MovinetClassifier(
@@ -56,6 +59,7 @@ def main(_) -> None:
5659
backbone_3d_2plus1d = movinet.Movinet(
5760
model_id=FLAGS.model_id,
5861
causal=FLAGS.causal,
62+
se_type=FLAGS.se_type,
5963
conv_type='3d_2plus1d',
6064
use_positional_encoding=FLAGS.use_positional_encoding)
6165
model_3d_2plus1d = movinet_model.MovinetClassifier(

official/vision/beta/projects/movinet/tools/convert_3d_2plus1d_test.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@ def test_convert_model(self):
3636
model_3d_2plus1d = movinet_model.MovinetClassifier(
3737
backbone=movinet.Movinet(
3838
model_id='a0',
39+
se_type='2plus3d',
3940
conv_type='3d_2plus1d'),
4041
num_classes=600)
4142
model_3d_2plus1d.build([1, 1, 1, 1, 3])

0 commit comments

Comments
 (0)