@@ -39,19 +39,30 @@ def __call__(self, pic):
39
39
if isinstance (pic , np .ndarray ):
40
40
# handle numpy array
41
41
img = torch .from_numpy (pic .transpose ((2 , 0 , 1 )))
42
+ # backard compability
43
+ return img .float ().div (255 )
44
+ # handle PIL Image
45
+ if pic .mode == 'I' :
46
+ img = torch .from_numpy (np .array (pic , np .int32 ))
47
+ elif pic .mode == 'I;16' :
48
+ img = torch .from_numpy (np .array (pic , np .int16 ))
42
49
else :
43
- # handle PIL Image
44
50
img = torch .ByteTensor (torch .ByteStorage .from_buffer (pic .tobytes ()))
45
- # PIL image mode: 1, L, P, I, F, RGB, YCbCr, RGBA, CMYK
46
- if pic .mode == 'YCbCr' :
47
- nchannel = 3
48
- else :
49
- nchannel = len (pic .mode )
50
- img = img .view (pic .size [1 ], pic .size [0 ], nchannel )
51
- # put it from HWC to CHW format
52
- # yikes, this transpose takes 80% of the loading time/CPU
53
- img = img .transpose (0 , 1 ).transpose (0 , 2 ).contiguous ()
54
- return img .float ().div (255 )
51
+ # PIL image mode: 1, L, P, I, F, RGB, YCbCr, RGBA, CMYK
52
+ if pic .mode == 'YCbCr' :
53
+ nchannel = 3
54
+ elif pic .mode == 'I;16' :
55
+ nchannel = 1
56
+ else :
57
+ nchannel = len (pic .mode )
58
+ img = img .view (pic .size [1 ], pic .size [0 ], nchannel )
59
+ # put it from HWC to CHW format
60
+ # yikes, this transpose takes 80% of the loading time/CPU
61
+ img = img .transpose (0 , 1 ).transpose (0 , 2 ).contiguous ()
62
+ if isinstance (img , torch .ByteTensor ):
63
+ return img .float ().div (255 )
64
+ else :
65
+ return img
55
66
56
67
57
68
class ToPILImage (object ):
@@ -67,7 +78,6 @@ def __call__(self, pic):
67
78
if torch .is_tensor (pic ):
68
79
npimg = np .transpose (pic .numpy (), (1 , 2 , 0 ))
69
80
assert isinstance (npimg , np .ndarray ), 'pic should be Tensor or ndarray'
70
-
71
81
if npimg .shape [2 ] == 1 :
72
82
npimg = npimg [:, :, 0 ]
73
83
@@ -83,7 +93,6 @@ def __call__(self, pic):
83
93
if npimg .dtype == np .uint8 :
84
94
mode = 'RGB'
85
95
assert mode is not None , '{} is not supported' .format (npimg .dtype )
86
-
87
96
return Image .fromarray (npimg , mode = mode )
88
97
89
98
0 commit comments