
Commit d9b8d00

readme and improvements

1 parent 62cfcb7

8 files changed, +227 -22 lines changed

README.md

Lines changed: 94 additions & 0 deletions
From Source:

```
pip install -r requirements.txt
pip install .
```

# Datasets

Datasets have the API:

- `__getitem__`
- `__len__`

They all subclass from `torch.utils.data.Dataset`. Hence, they can all be loaded in parallel with multiple worker processes (Python multiprocessing) using a standard `torch.utils.data.DataLoader`.

For example:

`torch.utils.data.DataLoader(coco_cap, batch_size=args.batchSize, shuffle=True, num_workers=args.nThreads)`

In the constructor, each dataset has a slightly different API as needed, but they all take the keyword args:

- `transform` - a function that takes in an image and returns a transformed version
  - common transforms such as `ToTensor` and `RandomCrop`, which can be composed together with `transforms.Compose` (see the transforms section below)
- `target_transform` - a function that takes in the target and transforms it; for example, take in the caption string and return a tensor of word indices (a sketch follows below)
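For instance, a minimal caption-to-indices `target_transform` might look like the sketch below. The `vocab` mapping and `caption_to_tensor` helper are hypothetical illustrations, not part of torchvision:

```python
import torch

# Hypothetical toy vocabulary; a real one would be built from the caption corpus.
vocab = {'<unk>': 0, 'a': 1, 'dog': 2, 'runs': 3}

def caption_to_tensor(caption):
    # Look up each word, falling back to <unk> for out-of-vocabulary words.
    indices = [vocab.get(word, vocab['<unk>']) for word in caption.lower().split()]
    return torch.LongTensor(indices)

# Passed to a dataset as target_transform=caption_to_tensor.
```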
The following datasets are available:

- COCO (Captioning and Detection)
- LSUN Classification
- Imagenet-12
- ImageFolder
### COCO

This requires the [COCO API to be installed](https://github.com/pdollar/coco/tree/master/PythonAPI).

#### Captions:

`dset.CocoCaptions(root="dir where images are", annFile="json annotation file", [transform, target_transform])`

Example:
```python
import torchvision.datasets as dset
import torchvision.transforms as transforms
cap = dset.CocoCaptions(root='dir where images are',
                        annFile='json annotation file',
                        transform=transforms.ToTensor())

print('Number of samples:', len(cap))
img, target = cap[3]  # load 4th sample

print(img.size())
print(target)
```

Output:

```
```
#### Detection:

`dset.CocoDetection(root="dir where images are", annFile="json annotation file", [transform, target_transform])`
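Usage is analogous to the captions example above; a minimal sketch, where the paths are placeholders and the target is assumed to be the COCO annotations for the image:

```python
import torchvision.datasets as dset

det = dset.CocoDetection(root='dir where images are',
                         annFile='json annotation file')

img, target = det[0]  # target: the COCO annotations for this image
print('Number of samples:', len(det))
```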
### LSUN

`dset.LSUN(db_path, classes='train', [transform, target_transform])`

- `db_path` - root directory for the database files
- `classes` - one of:
  - `'train'` - all categories, training set
  - `'val'` - all categories, validation set
  - `'test'` - all categories, test set
  - `['bedroom_train', 'church_train', ...]` - a list of categories to load
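For example, a sketch loading just two categories; it assumes the LSUN lmdb databases have already been downloaded under `db_path` (a placeholder here):

```python
import torchvision.datasets as dset
import torchvision.transforms as transforms

lsun = dset.LSUN(db_path='path/to/lsun',
                 classes=['bedroom_train', 'church_train'],
                 transform=transforms.ToTensor())
print('Number of samples:', len(lsun))
```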
### ImageFolder

A generic data loader where the images are arranged in this way:

```
root/dog/xxx.png
root/dog/xxy.png
root/dog/xxz.png

root/cat/123.png
root/cat/nsdf3.png
root/cat/asd932_.png
```

`dset.ImageFolder(root="root folder path", [transform, target_transform])`
It has the members:

- `self.classes` - the class names as a list
- `self.class_to_idx` - a mapping from class name to class index
- `self.imgs` - the list of (image path, class index) tuples
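A typical usage sketch combining `ImageFolder` with a `DataLoader`; the `root` path, batch size, and worker count are placeholders:

```python
import torch
import torchvision.datasets as dset
import torchvision.transforms as transforms

data = dset.ImageFolder(root='path/to/root', transform=transforms.ToTensor())
print(data.classes)  # e.g. ['cat', 'dog']

loader = torch.utils.data.DataLoader(data, batch_size=4,
                                     shuffle=True, num_workers=2)
```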

### Imagenet-12

This is simply implemented with an ImageFolder dataset, after the data has been preprocessed [as described here](https://github.com/facebook/fb.resnet.torch/blob/master/INSTALL.md#download-the-imagenet-dataset).

setup.py

Lines changed: 1 addition & 16 deletions
```diff
@@ -9,21 +9,6 @@
 long_description = '''torch-vision provides DataLoaders, Pre-trained models
 and common transforms for torch for images and videos'''

-excluded = ['test']
-def exclude_package(pkg):
-    for exclude in excluded:
-        if pkg.startswith(exclude):
-            return True
-    return False
-
-def create_package_list(base_package):
-    return ([base_package] +
-            [base_package + '.' + pkg
-             for pkg
-             in find_packages(base_package)
-             if not exclude_package(pkg)])
-
-
 setup_info = dict(
     # Metadata
     name='torchvision',
@@ -36,7 +21,7 @@ def create_package_list(base_package):
     license='BSD',

     # Package info
-    packages=find_packages(exclude=('test',)), #create_package_list('torchvision'),
+    packages=find_packages(exclude=('test',)),

     zip_safe=True,
 )
```

test/smoke_test.py

Lines changed: 2 additions & 0 deletions
```diff
@@ -1,2 +1,4 @@
 import torch
 import torchvision
+import torchvision.datasets as dset
+import torchvision.transforms
```

torchvision/datasets/__init__.py

Lines changed: 3 additions & 1 deletion
```diff
@@ -2,4 +2,6 @@
 from .folder import ImageFolderDataset
 from .coco import CocoCaptionsDataset, CocoDetectionDataset

-__all__ = ('LSUNDataset', 'LSUNClassDataset')
+__all__ = ('LSUNDataset', 'LSUNClassDataset',
+           'ImageFolderDataset',
+           'CocoCaptionsDataset', 'CocoDetectionDataset')
```

torchvision/datasets/coco.py

Lines changed: 2 additions & 2 deletions
```diff
@@ -3,7 +3,7 @@
 import os
 import os.path

-class CocoCaptionsDataset(data.Dataset):
+class CocoCaptions(data.Dataset):
     def __init__(self, root, annFile, transform=None, target_transform=None):
         from pycocotools.coco import COCO
         self.root = root
@@ -33,7 +33,7 @@ def __getitem__(self, index):
     def __len__(self):
         return len(self.ids)

-class CocoDetectionDataset(data.Dataset):
+class CocoDetection(data.Dataset):
     def __init__(self, root, annFile, transform=None, target_transform=None):
         from pycocotools.coco import COCO
         self.root = root
```

torchvision/datasets/folder.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ def make_dataset(dir, class_to_idx):
3333

3434
return images
3535

36-
class ImageFolderDataset(data.Dataset):
36+
class ImageFolder(data.Dataset):
3737
def __init__(self, root, transform=None, target_transform=None):
3838
classes, class_to_idx = find_classes(root)
3939
imgs = make_dataset(root, class_to_idx)

torchvision/datasets/lsun.py

Lines changed: 2 additions & 2 deletions
```diff
@@ -10,7 +10,7 @@
 else:
     import pickle

-class LSUNClassDataset(data.Dataset):
+class LSUNClass(data.Dataset):
     def __init__(self, db_path, transform=None, target_transform=None):
         import lmdb
         self.db_path = db_path
@@ -53,7 +53,7 @@ def __len__(self):
     def __repr__(self):
         return self.__class__.__name__ + ' (' + self.db_path + ')'

-class LSUNDataset(data.Dataset):
+class LSUN(data.Dataset):
     """
     db_path = root directory for the database files
     classes = 'train' | 'val' | 'test' | ['bedroom_train', 'church_train', ...]
     """
```

torchvision/transforms.py

Lines changed: 122 additions & 0 deletions
```python
import torch
import math
import random
from PIL import Image


class Compose(object):
    "Composes several transforms together, applied in sequence"
    def __init__(self, transforms):
        self.transforms = transforms

    def __call__(self, img):
        for t in self.transforms:
            img = t(img)
        return img


class ToTensor(object):
    "Converts a 3-channel PIL.Image to a C x H x W torch.FloatTensor"
    def __call__(self, pic):
        img = torch.ByteTensor(torch.ByteStorage.from_buffer(pic.tobytes()))
        # PIL stores size as (width, height); the raw buffer is laid out H x W x 3
        img = img.view(pic.size[1], pic.size[0], 3)
        # put it in CHW format
        # yikes, this transpose takes 80% of the loading time/CPU
        img = img.transpose(0, 2).transpose(1, 2).contiguous()
        return img.float()


class Normalize(object):
    "Normalizes a tensor channel-wise in place: (channel - mean) / std"
    def __init__(self, mean, std):
        self.mean = mean
        self.std = std

    def __call__(self, tensor):
        for t, m, s in zip(tensor, self.mean, self.std):
            t.sub_(m).div_(s)
        return tensor


class Scale(object):
    "Scales the smaller edge to size, keeping the aspect ratio"
    def __init__(self, size, interpolation=Image.BILINEAR):
        self.size = size
        self.interpolation = interpolation

    def __call__(self, img):
        w, h = img.size
        if (w <= h and w == self.size) or (h <= w and h == self.size):
            return img
        if w < h:
            return img.resize((self.size, int(round(self.size * h / w))), self.interpolation)
        else:
            return img.resize((int(round(self.size * w / h)), self.size), self.interpolation)


class CenterCrop(object):
    "Crop to a centered size x size rectangle"
    def __init__(self, size):
        self.size = size

    def __call__(self, img):
        w, h = img.size
        x1 = int(round((w - self.size) / 2.))
        y1 = int(round((h - self.size) / 2.))
        return img.crop((x1, y1, x1 + self.size, y1 + self.size))


class RandomCrop(object):
    "Random crop from a larger image, with optional zero padding"
    def __init__(self, size, padding=0):
        self.size = size
        self.padding = padding

    def __call__(self, img):
        if self.padding > 0:
            raise NotImplementedError()

        w, h = img.size
        if w == self.size and h == self.size:
            return img

        x1 = random.randint(0, w - self.size)
        y1 = random.randint(0, h - self.size)
        return img.crop((x1, y1, x1 + self.size, y1 + self.size))


class RandomHorizontalFlip(object):
    "Horizontal flip with 0.5 probability"
    def __call__(self, img):
        if random.random() < 0.5:
            return img.transpose(Image.FLIP_LEFT_RIGHT)
        return img


class RandomSizedCrop(object):
    "Random crop covering 0.08-1.0 of the area with aspect ratio 3/4 - 4/3 (Inception-style)"
    def __init__(self, size, interpolation=Image.BILINEAR):
        self.size = size
        self.interpolation = interpolation

    def __call__(self, img):
        for attempt in range(10):
            area = img.size[0] * img.size[1]
            target_area = random.uniform(0.08, 1.0) * area
            aspect_ratio = random.uniform(3. / 4, 4. / 3)

            w = int(round(math.sqrt(target_area * aspect_ratio)))
            h = int(round(math.sqrt(target_area / aspect_ratio)))

            # swap w and h half the time to cover both orientations
            if random.random() < 0.5:
                w, h = h, w

            if w <= img.size[0] and h <= img.size[1]:
                x1 = random.randint(0, img.size[0] - w)
                y1 = random.randint(0, img.size[1] - h)

                img = img.crop((x1, y1, x1 + w, y1 + h))
                assert(img.size == (w, h))

                return img.resize((self.size, self.size), self.interpolation)

        # Fallback: scale the smaller edge, then center-crop
        scale = Scale(self.size, interpolation=self.interpolation)
        crop = CenterCrop(self.size)
        return crop(scale(img))
```
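A brief sketch tying these transforms to the datasets above; the crop size and the mean/std values are illustrative placeholders, not recommended constants:

```python
import torchvision.transforms as transforms

# Typical training-time pipeline: random crop + flip, then tensor conversion
# and channel-wise normalization.
train_transform = transforms.Compose([
    transforms.RandomSizedCrop(224),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
])

# Passed to any dataset as transform=train_transform.
```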
