Skip to content

Commit 62cfcb7

Browse files
committed
first commit
0 parents  commit 62cfcb7

File tree

9 files changed

+328
-0
lines changed

9 files changed

+328
-0
lines changed

README.md

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
# torch-vision
2+
3+
This repository consists of:
4+
5+
- `vision.datasets` : Data loaders for popular vision datasets
6+
- `vision.transforms` : Common image transformations such as random crop, rotations etc.
7+
- `[WIP] vision.models` : Definitions and pre-trained models for popular architectures such as AlexNet, VGG, ResNet etc.
8+
9+
# Installation
10+
11+
Binaries:
12+
13+
```bash
14+
conda install pytorch-vision -c https://conda.anaconda.org/t/6N-MsQ4WZ7jo/soumith
15+
```
16+
17+
From Source:
18+
19+
```bash
20+
pip install -r requirements.txt
21+
pip install .
22+
```

requirements.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
pillow

setup.py

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
#!/usr/bin/env python
2+
import os
3+
import shutil
4+
import sys
5+
from setuptools import setup, find_packages
6+
7+
# Release version passed to setup() as ``version``.
VERSION = '0.1.5'

# Text passed to setup() as ``long_description``.
long_description = (
    'torch-vision provides DataLoaders, Pre-trained models\n'
    'and common transforms for torch for images and videos'
)

# Package-name prefixes that exclude_package() rejects.
excluded = ['test']
13+
def exclude_package(pkg):
    """Return True when ``pkg`` starts with any prefix listed in ``excluded``."""
    return any(pkg.startswith(prefix) for prefix in excluded)
18+
19+
def create_package_list(base_package):
    """Return ``base_package`` followed by its dotted, non-excluded sub-packages.

    Sub-packages are discovered with ``find_packages(base_package)``; every
    name for which ``exclude_package`` returns True is dropped.
    """
    packages = [base_package]
    for sub_package in find_packages(base_package):
        if exclude_package(sub_package):
            continue
        packages.append(base_package + '.' + sub_package)
    return packages
25+
26+
27+
# Keyword arguments handed to setuptools.setup() below.
setup_info = {
    # Metadata
    'name': 'torchvision',
    'version': VERSION,
    'author': 'PyTorch Core Team',
    'author_email': '[email protected]',
    'url': 'https://github.com/pytorch/vision',
    'description': 'image and video datasets and models for torch deep learning',
    'long_description': long_description,
    'license': 'BSD',

    # Package info: every discovered package except 'test'.
    # (create_package_list('torchvision') is the unused alternative.)
    'packages': find_packages(exclude=('test',)),

    'zip_safe': True,
}

setup(**setup_info)

test/smoke_test.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
import torch
2+
import torchvision

torchvision/__init__.py

Whitespace-only changes.

torchvision/datasets/__init__.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
from .lsun import LSUNDataset, LSUNClassDataset
2+
from .folder import ImageFolderDataset
3+
from .coco import CocoCaptionsDataset, CocoDetectionDataset
4+
5+
# Public API of this package: list every dataset class imported above so
# that ``from torchvision.datasets import *`` exports all of them, not only
# the LSUN ones.
__all__ = ('LSUNDataset', 'LSUNClassDataset',
           'ImageFolderDataset',
           'CocoCaptionsDataset', 'CocoDetectionDataset')

torchvision/datasets/coco.py

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,60 @@
1+
import torch.utils.data as data
2+
from PIL import Image
3+
import os
4+
import os.path
5+
6+
class CocoCaptionsDataset(data.Dataset):
    """Dataset of images paired with their captions, read through pycocotools.

    Args:
        root: Directory onto which each image's ``file_name`` is joined.
        annFile: Annotation file path handed to ``pycocotools.coco.COCO``.
        transform: Optional callable applied to the RGB PIL image.
        target_transform: Optional callable applied to the list of captions.
    """

    def __init__(self, root, annFile, transform=None, target_transform=None):
        # Imported here so that pycocotools is only needed when this
        # dataset is instantiated.
        from pycocotools.coco import COCO
        self.root = root
        self.coco = COCO(annFile)
        # Materialize the ids as a list: on Python 3 ``dict.keys()`` is a
        # view and cannot be indexed in __getitem__.
        self.ids = list(self.coco.imgs.keys())
        self.transform = transform
        self.target_transform = target_transform

    def __getitem__(self, index):
        """Return ``(image, captions)`` for the image at position ``index``."""
        coco = self.coco
        img_id = self.ids[index]
        ann_ids = coco.getAnnIds(imgIds=img_id)
        anns = coco.loadAnns(ann_ids)
        target = [ann['caption'] for ann in anns]

        path = coco.loadImgs(img_id)[0]['file_name']

        img = Image.open(os.path.join(self.root, path)).convert('RGB')
        if self.transform is not None:
            img = self.transform(img)

        if self.target_transform is not None:
            target = self.target_transform(target)

        return img, target

    def __len__(self):
        """Return the number of image ids in the annotation file."""
        return len(self.ids)
35+
36+
class CocoDetectionDataset(data.Dataset):
    """Dataset of images paired with their raw COCO annotation records.

    Args:
        root: Directory onto which each image's ``file_name`` is joined.
        annFile: Annotation file path handed to ``pycocotools.coco.COCO``.
        transform: Optional callable applied to the RGB PIL image.
        target_transform: Optional callable applied to the list of
            annotations returned by ``coco.loadAnns``.
    """

    def __init__(self, root, annFile, transform=None, target_transform=None):
        # Imported here so that pycocotools is only needed when this
        # dataset is instantiated.
        from pycocotools.coco import COCO
        self.root = root
        self.coco = COCO(annFile)
        # Materialize the ids as a list: on Python 3 ``dict.keys()`` is a
        # view and cannot be indexed in __getitem__.
        self.ids = list(self.coco.imgs.keys())
        self.transform = transform
        self.target_transform = target_transform

    def __getitem__(self, index):
        """Return ``(image, annotations)`` for the image at position ``index``."""
        coco = self.coco
        img_id = self.ids[index]
        ann_ids = coco.getAnnIds(imgIds=img_id)
        target = coco.loadAnns(ann_ids)

        path = coco.loadImgs(img_id)[0]['file_name']

        img = Image.open(os.path.join(self.root, path)).convert('RGB')
        # The transform was accepted by __init__ but never applied; apply it
        # the same way CocoCaptionsDataset does.
        if self.transform is not None:
            img = self.transform(img)

        if self.target_transform is not None:
            target = self.target_transform(target)

        return img, target

    def __len__(self):
        """Return the number of image ids in the annotation file."""
        return len(self.ids)

torchvision/datasets/folder.py

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
import torch.utils.data as data
2+
3+
from PIL import Image
4+
import os
5+
import os.path
6+
7+
# File-name suffixes recognised as images. The list is kept unchanged for
# any code that reads it; is_image_file() compares case-insensitively.
IMG_EXTENSIONS = [
    '.jpg', '.JPG', '.jpeg', '.JPEG',
    '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP',
]


def is_image_file(filename):
    """Return True if ``filename`` ends with one of ``IMG_EXTENSIONS``.

    The comparison ignores case, so mixed-case suffixes such as ``.Jpg``
    are accepted in addition to the all-lower and all-upper spellings
    listed in ``IMG_EXTENSIONS``.
    """
    suffixes = tuple(extension.lower() for extension in IMG_EXTENSIONS)
    # str.endswith accepts a tuple, so one call checks every suffix.
    return filename.lower().endswith(suffixes)
14+
15+
def find_classes(dir):
    """Map every class sub-directory of ``dir`` to a contiguous integer index.

    Only directories are treated as classes. Stray files in ``dir`` are
    ignored, matching ``make_dataset``, which also skips non-directories;
    otherwise such files would receive class indices of their own.

    Args:
        dir: Root directory whose immediate sub-directories are the classes.

    Returns:
        ``(classes, class_to_idx)``: the sorted list of class names and a
        dict mapping each name to its position in that list.
    """
    classes = sorted(
        entry for entry in os.listdir(dir)
        if os.path.isdir(os.path.join(dir, entry))
    )
    class_to_idx = {name: index for index, name in enumerate(classes)}
    return classes, class_to_idx
20+
21+
def make_dataset(dir, class_to_idx):
    """Collect ``(relative_path, class_index)`` pairs for the images in ``dir``.

    Args:
        dir: Root directory containing one sub-directory per class.
        class_to_idx: Dict mapping each sub-directory name to its index.

    Returns:
        A list of ``(path, index)`` tuples, where ``path`` is
        ``'<class dir>/<file name>'`` relative to ``dir``. Entries are
        ordered by class directory name, then by file name.
    """
    images = []
    # os.listdir() returns entries in arbitrary order; sorting makes the
    # sample order reproducible across runs and file systems.
    for target in sorted(os.listdir(dir)):
        class_dir = os.path.join(dir, target)
        if not os.path.isdir(class_dir):
            continue

        for filename in sorted(os.listdir(class_dir)):
            if is_image_file(filename):
                path = '{0}/{1}'.format(target, filename)
                images.append((path, class_to_idx[target]))

    return images
35+
36+
class ImageFolderDataset(data.Dataset):
    """Image dataset built from a directory with one sub-directory per class.

    Args:
        root: Directory scanned by ``find_classes`` and ``make_dataset``.
        transform: Optional callable applied to the RGB PIL image.
        target_transform: Optional callable applied to the class index.
    """

    def __init__(self, root, transform=None, target_transform=None):
        found_classes, index_by_class = find_classes(root)
        samples = make_dataset(root, index_by_class)

        self.root = root
        self.transform = transform
        self.target_transform = target_transform
        self.classes = found_classes
        self.class_to_idx = index_by_class
        self.imgs = samples

    def __getitem__(self, index):
        """Return ``(image, class_index)`` for the sample at ``index``."""
        relative_path, label = self.imgs[index]
        full_path = os.path.join(self.root, relative_path)
        image = Image.open(full_path).convert('RGB')

        if self.transform is not None:
            image = self.transform(image)
        if self.target_transform is not None:
            label = self.target_transform(label)

        return image, label

    def __len__(self):
        """Return the number of collected samples."""
        return len(self.imgs)

torchvision/datasets/lsun.py

Lines changed: 135 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,135 @@
1+
import io
import os
import os.path
import string
import sys

import torch.utils.data as data
from PIL import Image

if sys.version_info[0] == 2:
    import cPickle as pickle
    import StringIO  # Python 2 only; this module does not exist on Python 3
else:
    import pickle
12+
13+
class LSUNClassDataset(data.Dataset):
    """Images stored in a single LMDB database, with no label.

    Args:
        db_path: Path of the LMDB database, opened read-only.
        transform: Optional callable applied to the decoded RGB image.
        target_transform: Optional callable applied to the target, which
            is always ``None`` for this dataset.
    """

    def __init__(self, db_path, transform=None, target_transform=None):
        # Imported here so that lmdb is only needed when this dataset is
        # instantiated.
        import lmdb
        self.db_path = db_path
        self.env = lmdb.open(db_path, map_size=1099511627776,
                             max_readers=100, readonly=True)
        with self.env.begin(write=False) as txn:
            self.length = txn.stat()['entries']

        # The list of keys is cached in a pickle file whose name is derived
        # from db_path; the path is relative, so the file lives in the
        # current working directory.
        cache_file = '_cache_' + db_path.replace('/', '_')
        if os.path.isfile(cache_file):
            # NOTE(review): pickle.load can execute arbitrary code; this is
            # only safe as long as the cache file is trusted.
            with open(cache_file, 'rb') as cache:
                self.keys = pickle.load(cache)
        else:
            with self.env.begin(write=False) as txn:
                self.keys = [key for key, _ in txn.cursor()]
            with open(cache_file, 'wb') as cache:
                pickle.dump(self.keys, cache)

        self.transform = transform
        self.target_transform = target_transform

    def __getitem__(self, index):
        """Return ``(image, target)`` for the ``index``-th key."""
        with self.env.begin(write=False) as txn:
            imgbuf = txn.get(self.keys[index])

        # io.BytesIO wraps the stored bytes on both Python 2 and Python 3;
        # the StringIO module used previously exists only on Python 2.
        img = Image.open(io.BytesIO(imgbuf)).convert('RGB')
        if self.transform is not None:
            img = self.transform(img)

        # There is no label at this level, so the target starts as None.
        target = None
        if self.target_transform is not None:
            target = self.target_transform(target)

        return img, target

    def __len__(self):
        """Return the entry count reported by the LMDB environment."""
        return self.length

    def __repr__(self):
        return self.__class__.__name__ + ' (' + self.db_path + ')'
55+
56+
class LSUNDataset(data.Dataset):
    """
    Concatenation of one LSUNClassDataset per requested class.

    db_path = root directory for the database files
    classes = 'train' | 'val' | 'test' |
              ['bedroom_train', 'church_outdoor_train', ...]
    transform = optional callable forwarded to every per-class dataset
    target_transform = optional callable applied to the integer class index

    Raises ValueError when ``classes`` is neither one of the split names
    nor a list of '<category>_<split>' names.
    """
    def __init__(self, db_path, classes='train',
                 transform=None, target_transform=None):
        categories = ['bedroom', 'bridge', 'church_outdoor', 'classroom',
                      'conference_room', 'dining_room', 'kitchen',
                      'living_room', 'restaurant', 'tower']
        dset_opts = ['train', 'val', 'test']
        self.db_path = db_path

        # A bare split name expands to that split of every category.
        if isinstance(classes, str) and classes in dset_opts:
            classes = [c + '_' + classes for c in categories]
        if not isinstance(classes, list):
            raise ValueError('Unknown option for classes')

        for c in classes:
            # 'church_outdoor_train' -> ('church_outdoor', '_', 'train').
            # str.rpartition replaces string.join, which was removed in
            # Python 3.
            category, _, postfix = c.rpartition('_')
            if category not in categories:
                raise ValueError('Unknown LSUN class: ' + category + '.'
                                 'Options are: ' + str(categories))
            if postfix not in dset_opts:
                raise ValueError('Unknown postfix: ' + postfix + '.'
                                 'Options are: ' + str(dset_opts))
        self.classes = classes

        # for each class, create an LSUNClassDataset
        self.dbs = [
            LSUNClassDataset(db_path=db_path + '/' + c + '_lmdb',
                             transform=transform)
            for c in self.classes
        ]

        # indices[i] is the cumulative number of samples in dbs[0..i].
        self.indices = []
        count = 0
        for db in self.dbs:
            count += len(db)
            self.indices.append(count)

        self.length = count
        self.target_transform = target_transform

    def __getitem__(self, index):
        """Return ``(item, target)`` where ``item`` comes from the per-class
        dataset that contains ``index`` and ``target`` is that dataset's
        position in ``self.dbs`` (after ``target_transform``, if any).
        """
        target = 0
        sub = 0
        for ind in self.indices:
            if index < ind:
                break
            target += 1
            # indices are already cumulative, so the offset of the next
            # database is ``ind`` itself; adding them up (as before) gave a
            # wrong local index from the third database onwards.
            sub = ind

        db = self.dbs[target]
        index = index - sub

        if self.target_transform is not None:
            target = self.target_transform(target)

        # NOTE(review): db[index] is returned as-is, so with
        # LSUNClassDataset the first element is itself an (img, None)
        # tuple; kept unchanged for backward compatibility.
        return db[index], target

    def __len__(self):
        """Return the total number of samples over all class databases."""
        return self.length

    def __repr__(self):
        return self.__class__.__name__ + ' (' + self.db_path + ')'
125+
126+
if __name__ == '__main__':
    # Manual smoke test against a developer-local LSUN directory.
    # Single-class variant, left disabled:
    #   lsun = LSUNClassDataset(db_path='/home/soumith/local/lsun/train/bedroom_train_lmdb')
    #   a = lsun[0]
    lsun = LSUNDataset(
        db_path='/home/soumith/local/lsun/train',
        classes=['bedroom_train', 'church_outdoor_train'],
    )
    print(lsun.classes)
    print(lsun.dbs)
    # Fetch the very last sample of the concatenated dataset.
    last_index = len(lsun) - 1
    a, t = lsun[last_index]
    print(a)
    print(t)

0 commit comments

Comments
 (0)