@@ -50,6 +50,17 @@ class DataConfig(cfg.DataConfig):
50
50
min_image_size : int = 256
51
51
52
52
53
+ def kinetics400 (is_training ):
54
+ """Generated Kinectics 400 dataset configs."""
55
+ return DataConfig (
56
+ name = 'kinetics400' ,
57
+ num_classes = 400 ,
58
+ is_training = is_training ,
59
+ split = 'train' if is_training else 'valid' ,
60
+ num_examples = 215570 if is_training else 17706 ,
61
+ feature_shape = (64 , 224 , 224 , 3 ) if is_training else (250 , 224 , 224 , 3 ))
62
+
63
+
53
64
def kinetics600 (is_training ):
54
65
"""Generated Kinectics 600 dataset configs."""
55
66
return DataConfig (
@@ -153,9 +164,35 @@ def video_classification() -> cfg.ExperimentConfig:
153
164
])
154
165
155
166
167
+ @exp_factory .register_config_factory ('video_classification_kinetics400' )
168
+ def video_classification_kinetics400 () -> cfg .ExperimentConfig :
169
+ """Video classification on Kinectics 400 with resnet."""
170
+ train_dataset = kinetics400 (is_training = True )
171
+ validation_dataset = kinetics400 (is_training = False )
172
+ task = VideoClassificationTask (
173
+ model = VideoClassificationModel (
174
+ backbone = backbones_3d .Backbone3D (
175
+ type = 'resnet_3d' , resnet_3d = backbones_3d .ResNet3D50 ()),
176
+ norm_activation = common .NormActivation (
177
+ norm_momentum = 0.9 , norm_epsilon = 1e-5 )),
178
+ losses = Losses (l2_weight_decay = 1e-4 ),
179
+ train_data = train_dataset ,
180
+ validation_data = validation_dataset )
181
+ config = cfg .ExperimentConfig (
182
+ runtime = cfg .RuntimeConfig (mixed_precision_dtype = 'bfloat16' ),
183
+ task = task ,
184
+ restrictions = [
185
+ 'task.train_data.is_training != None' ,
186
+ 'task.validation_data.is_training != None' ,
187
+ 'task.train_data.num_classes == task.validation_data.num_classes' ,
188
+ ])
189
+ add_trainer (config , train_batch_size = 1024 , eval_batch_size = 64 )
190
+ return config
191
+
192
+
156
193
@exp_factory .register_config_factory ('video_classification_kinetics600' )
157
194
def video_classification_kinetics600 () -> cfg .ExperimentConfig :
158
- """Video classification on Videonet with resnet."""
195
+ """Video classification on Kinectics 600 with resnet."""
159
196
train_dataset = kinetics600 (is_training = True )
160
197
validation_dataset = kinetics600 (is_training = False )
161
198
task = VideoClassificationTask (
@@ -176,5 +213,4 @@ def video_classification_kinetics600() -> cfg.ExperimentConfig:
176
213
'task.train_data.num_classes == task.validation_data.num_classes' ,
177
214
])
178
215
add_trainer (config , train_batch_size = 1024 , eval_batch_size = 64 )
179
-
180
216
return config
0 commit comments