-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy path: generate_dataset_npy.py
More file actions
41 lines (33 loc) · 1.14 KB
/
generate_dataset_npy.py
File metadata and controls
41 lines (33 loc) · 1.14 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
import os
import random

import numpy as np
import torch
import torchvision
from tqdm import tqdm

from deep_learning.dataset import LoopedVideoASLDataset
from utils.parser import get_parser
def main():
    """Precompute transformed video tensors and save each one as a .npy file.

    Parses CLI arguments via the project parser, seeds all RNG sources with
    ``args.seed`` for reproducibility, builds the WLASL2000 looped-video
    dataset with a resize(256) -> center-crop(224) -> tensor pipeline, and
    writes each sample's video tensor to ``data/npy/videos/<video_name>.npy``.
    """
    parser = get_parser()
    args = parser.parse_args()

    # Seed every RNG source the pipeline may touch so output is reproducible.
    torch.manual_seed(args.seed)
    np.random.seed(args.seed)
    random.seed(args.seed)

    transforms = torchvision.transforms.Compose(
        [
            torchvision.transforms.ToPILImage("RGB"),
            torchvision.transforms.Resize((256, 256)),
            torchvision.transforms.CenterCrop(224),
            torchvision.transforms.ToTensor(),
        ]
    )
    sel_labels = ["SignType"]
    # transform=None at construction; the pipeline is attached afterwards via
    # set_transforms (presumably the dataset applies it per-frame — TODO confirm).
    dataset = LoopedVideoASLDataset("WLASL2000", "reduced_SignData.csv",
                                    sel_labels=sel_labels,
                                    drop_features=[],
                                    different_length=True, transform=None)
    dataset.set_transforms(transforms)

    # Create the output directory up front so np.save does not fail with
    # FileNotFoundError on a fresh checkout.
    out_dir = os.path.join("data", "npy", "videos")
    os.makedirs(out_dir, exist_ok=True)

    for i in tqdm(range(len(dataset))):
        sample = dataset[i]
        video_name = dataset.motions_keys[i]
        # sample[0] is the transformed video tensor; any labels in the sample
        # tuple are ignored here.
        np.save(os.path.join(out_dir, f"{video_name}.npy"), sample[0].numpy())
    # NOTE(review): removed the trailing bare exit() — it relied on the `site`
    # module builtin and served no purpose at the end of main(); the function
    # simply returns.
# Script entry point: run only when executed directly, not when imported.
if __name__ == '__main__':
    main()