Skip to content

Commit 6a834e9

Browse files
authored
Move resnet video models to single location (#1190)
* [WIP] Minor cleanups on R3d * Move all models to video/resnet.py * Remove old files * Make tests less memory intensive * Lint * Fix typo and add pretraing arg to training script
1 parent 4ec38d4 commit 6a834e9

File tree

10 files changed

+344
-450
lines changed

10 files changed

+344
-450
lines changed

references/video_classification/train.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -201,8 +201,7 @@ def main(args):
201201
pin_memory=True, collate_fn=collate_fn)
202202

203203
print("Creating model")
204-
# model = torchvision.models.video.__dict__[args.model](pretrained=args.pretrained)
205-
model = torchvision.models.video.__dict__[args.model]()
204+
model = torchvision.models.video.__dict__[args.model](pretrained=args.pretrained)
206205
model.to(device)
207206
if args.distributed and args.sync_bn:
208207
model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)

test/test_models.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@ def _test_detection_model(self, name):
6161
def _test_video_model(self, name):
6262
# the default input shape is
6363
# bs * num_channels * clip_len * h *w
64-
input_shape = (1, 3, 8, 112, 112)
64+
input_shape = (1, 3, 4, 112, 112)
6565
# test both basicblock and Bottleneck
6666
model = models.video.__dict__[name](num_classes=50)
6767
x = torch.rand(input_shape)
@@ -145,6 +145,7 @@ def do_test(self, model_name=model_name):
145145

146146
setattr(Tester, "test_" + model_name, do_test)
147147

148+
148149
for model_name in get_available_video_models():
149150

150151
def do_test(self, model_name=model_name):
Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1 @@
1-
from .r3d import *
2-
from .r2plus1d import *
3-
from .mixed_conv import *
1+
from .resnet import *

torchvision/models/video/_utils.py

Lines changed: 0 additions & 72 deletions
This file was deleted.

torchvision/models/video/mixed_conv.py

Lines changed: 0 additions & 49 deletions
This file was deleted.

torchvision/models/video/r2plus1d.py

Lines changed: 0 additions & 43 deletions
This file was deleted.

torchvision/models/video/r3d.py

Lines changed: 0 additions & 43 deletions
This file was deleted.

0 commit comments

Comments
 (0)