Skip to content

Commit 95f2b04

Browse files
richardaecntensorflower-gardener
authored andcommitted
Internal change
PiperOrigin-RevId: 338348738
1 parent b8c62df commit 95f2b04

File tree

1 file changed

+38
-2
lines changed

1 file changed

+38
-2
lines changed

official/vision/beta/configs/video_classification.py

Lines changed: 38 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,17 @@ class DataConfig(cfg.DataConfig):
5050
min_image_size: int = 256
5151

5252

53+
def kinetics400(is_training):
54+
"""Generated Kinectics 400 dataset configs."""
55+
return DataConfig(
56+
name='kinetics400',
57+
num_classes=400,
58+
is_training=is_training,
59+
split='train' if is_training else 'valid',
60+
num_examples=215570 if is_training else 17706,
61+
feature_shape=(64, 224, 224, 3) if is_training else (250, 224, 224, 3))
62+
63+
5364
def kinetics600(is_training):
5465
"""Generated Kinectics 600 dataset configs."""
5566
return DataConfig(
@@ -153,9 +164,35 @@ def video_classification() -> cfg.ExperimentConfig:
153164
])
154165

155166

167+
@exp_factory.register_config_factory('video_classification_kinetics400')
168+
def video_classification_kinetics400() -> cfg.ExperimentConfig:
169+
"""Video classification on Kinectics 400 with resnet."""
170+
train_dataset = kinetics400(is_training=True)
171+
validation_dataset = kinetics400(is_training=False)
172+
task = VideoClassificationTask(
173+
model=VideoClassificationModel(
174+
backbone=backbones_3d.Backbone3D(
175+
type='resnet_3d', resnet_3d=backbones_3d.ResNet3D50()),
176+
norm_activation=common.NormActivation(
177+
norm_momentum=0.9, norm_epsilon=1e-5)),
178+
losses=Losses(l2_weight_decay=1e-4),
179+
train_data=train_dataset,
180+
validation_data=validation_dataset)
181+
config = cfg.ExperimentConfig(
182+
runtime=cfg.RuntimeConfig(mixed_precision_dtype='bfloat16'),
183+
task=task,
184+
restrictions=[
185+
'task.train_data.is_training != None',
186+
'task.validation_data.is_training != None',
187+
'task.train_data.num_classes == task.validation_data.num_classes',
188+
])
189+
add_trainer(config, train_batch_size=1024, eval_batch_size=64)
190+
return config
191+
192+
156193
@exp_factory.register_config_factory('video_classification_kinetics600')
157194
def video_classification_kinetics600() -> cfg.ExperimentConfig:
158-
"""Video classification on Videonet with resnet."""
195+
"""Video classification on Kinectics 600 with resnet."""
159196
train_dataset = kinetics600(is_training=True)
160197
validation_dataset = kinetics600(is_training=False)
161198
task = VideoClassificationTask(
@@ -176,5 +213,4 @@ def video_classification_kinetics600() -> cfg.ExperimentConfig:
176213
'task.train_data.num_classes == task.validation_data.num_classes',
177214
])
178215
add_trainer(config, train_batch_size=1024, eval_batch_size=64)
179-
180216
return config

0 commit comments

Comments
 (0)