1+ from __future__ import division
12import torch
23import math
34import random
45from PIL import Image
56import numpy as np
6-
7+ import numbers
78
89class Compose (object ):
10+ """ Composes several transforms together.
11+ For example:
12+ >>> transforms.Compose([
13+ >>> transforms.CenterCrop(10),
14+ >>> transforms.ToTensor(),
15+ >>> ])
16+ """
917 def __init__ (self , transforms ):
1018 self .transforms = transforms
1119
@@ -16,6 +24,8 @@ def __call__(self, img):
1624
1725
1826class ToTensor (object ):
27+ """ Converts a PIL.Image (RGB) or numpy.ndarray (H x W x C) in the range [0, 255]
28+ to a torch.FloatTensor of shape (C x H x W) in the range [0.0, 1.0] """
1929 def __call__ (self , pic ):
2030 if isinstance (pic , np .ndarray ):
2131 # handle numpy array
@@ -24,24 +34,50 @@ def __call__(self, pic):
2434 # handle PIL Image
2535 img = torch .ByteTensor (torch .ByteStorage .from_buffer (pic .tobytes ()))
2636 img = img .view (pic .size [0 ], pic .size [1 ], 3 )
27- # put it in CHW format
37+ # put it from WHC to CHW format
2838 # yikes, this transpose takes 80% of the loading time/CPU
29- img = img .transpose (0 , 2 ).transpose (1 , 2 ).contiguous ()
30- return img .float ()
39+ img = img .transpose (0 , 2 ).contiguous ()
40+ return img .float ().div (255 )
41+
class ToPILImage(object):
    """Convert a tensor or a numpy array into a PIL.Image.

    Accepted inputs:
      * a torch.*Tensor of range [0, 1] and shape C x H x W; it is
        multiplied by 255, cast to bytes and rearranged to H x W x C
      * a numpy ndarray of dtype=uint8, range [0, 255] and shape H x W x C,
        which is used as it is

    An input with a single channel (trailing axis of length 1) is handed to
    PIL as a plain H x W array, because Image.fromarray does not accept an
    H x W x 1 layout.

    Returns the PIL.Image built by Image.fromarray.
    """

    def __call__(self, pic):
        if isinstance(pic, np.ndarray):
            # handle numpy array: already laid out as H x W x C
            npimg = pic
        else:
            # handle tensor: scale to [0, 255], then go from CHW to HWC
            npimg = pic.mul(255).byte().numpy()
            npimg = np.transpose(npimg, (1, 2, 0))

        # Drop a singleton channel axis so single-channel data becomes a
        # 2-D array that Image.fromarray can turn into a greyscale image.
        if npimg.ndim == 3 and npimg.shape[2] == 1:
            npimg = npimg[:, :, 0]

        return Image.fromarray(npimg)
3156
class Normalize(object):
    """Normalize a torch.*Tensor channel by channel, in place.

    mean: one value per channel, e.g. (R, G, B)
    std: one value per channel, e.g. (R, G, B)

    Every channel is rewritten as (channel - mean) / std. The tensor that
    was passed in is modified and also returned.
    """

    def __init__(self, mean, std):
        self.std = std
        self.mean = mean

    def __call__(self, tensor):
        # TODO: make efficient
        # zip stops at the shortest of tensor / mean / std
        for channel, channel_mean, channel_std in zip(tensor, self.mean, self.std):
            channel.sub_(channel_mean)
            channel.div_(channel_std)
        return tensor
4171
4272
4373class Scale (object ):
44- "Scales the smaller edge to size"
74+ """ Rescales the input PIL.Image to the given 'size'.
75+ 'size' will be the size of the smaller edge.
76+ For example, if height > width, then image will be
77+ rescaled to (size * height / width, size)
78+ size: size of the smaller edge
79+ interpolation: Default: PIL.Image.BILINEAR
80+ """
4581 def __init__ (self , size , interpolation = Image .BILINEAR ):
4682 self .size = size
4783 self .interpolation = interpolation
@@ -51,52 +87,76 @@ def __call__(self, img):
5187 if (w <= h and w == self .size ) or (h <= w and h == self .size ):
5288 return img
5389 if w < h :
54- return img .resize ((w , int (round (h / w * self .size ))), self .interpolation )
90+ ow = self .size
91+ oh = int (self .size * h / w )
92+ return img .resize ((ow , oh ), self .interpolation )
5593 else :
56- return img .resize ((int (round (w / h * self .size )), h ), self .interpolation )
94+ oh = self .size
95+ ow = int (self .size * w / h )
96+ return img .resize ((ow , oh ), self .interpolation )
5797
5898
class CenterCrop(object):
    """Crop the given PIL.Image around its center.

    size: either a (target_height, target_width) pair, or a single number,
        in which case the crop is the square (size, size)
    """

    def __init__(self, size):
        # A bare number is expanded into a square (height, width) pair.
        square = isinstance(size, numbers.Number)
        self.size = (int(size), int(size)) if square else size

    def __call__(self, img):
        width, height = img.size
        crop_h, crop_w = self.size
        # Split the surplus on each axis evenly between the two sides.
        left = int(round((width - crop_w) / 2))
        top = int(round((height - crop_h) / 2))
        box = (left, top, left + crop_w, top + crop_h)
        return img.crop(box)
69116
70117
class RandomCrop(object):
    """Crop the given PIL.Image at a randomly chosen location.

    size: either a (target_height, target_width) pair, or a single number,
        in which case the crop is the square (size, size)
    padding: only 0 is supported; a positive value makes the call raise
        NotImplementedError
    """

    def __init__(self, size, padding=0):
        # A bare number is expanded into a square (height, width) pair.
        square = isinstance(size, numbers.Number)
        self.size = (int(size), int(size)) if square else size
        self.padding = padding

    def __call__(self, img):
        if self.padding > 0:
            raise NotImplementedError()

        width, height = img.size
        crop_h, crop_w = self.size
        # Nothing to cut when the image already has the target size.
        if width == crop_w and height == crop_h:
            return img

        # Draw the horizontal offset first, then the vertical one.
        left = random.randint(0, width - crop_w)
        top = random.randint(0, height - crop_h)
        box = (left, top, left + crop_w, top + crop_h)
        return img.crop(box)
88142
89143
class RandomHorizontalFlip(object):
    """Mirror the given PIL.Image left-to-right with a probability of 0.5.

    One value is drawn from random.random() per call; the image is flipped
    when that value is below 0.5 and returned untouched otherwise.
    """

    def __call__(self, img):
        keep_as_is = random.random() >= 0.5
        if keep_as_is:
            return img
        return img.transpose(Image.FLIP_LEFT_RIGHT)
96151
97152
98153class RandomSizedCrop (object ):
99- "Random crop with size 0.08-1 and aspect ratio 3/4 - 4/3 (Inception-style)"
154+ """Random crop the given PIL.Image to a random size of (0.08 to 1.0) of the original size
155+ and a random aspect ratio of 3/4 to 4/3 of the original aspect ratio
156+ This is popularly used to train the Inception networks
157+ size: size of the smaller edge
158+ interpolation: Default: PIL.Image.BILINEAR
159+ """
100160 def __init__ (self , size , interpolation = Image .BILINEAR ):
101161 self .size = size
102162 self .interpolation = interpolation
0 commit comments