Skip to content

Commit f7bd2e4

Browse files
committed
Use linear RGB downscaling for most downscaling operations.
Didn't include nearest_aligned since it shouldn't matter when there's no blending of pixels going on.
1 parent 12d006f commit f7bd2e4

File tree

3 files changed

+80
-21
lines changed

3 files changed

+80
-21
lines changed

codes/dataops/augmennt/augmennt/functional.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
import collections
1616
import warnings
1717

18+
from ...colors import linear2srgb, srgb2linear
1819
from .common import preserve_channel_dim, preserve_shape
1920
from .common import _cv2_str2pad, _cv2_str2interpolation
2021

@@ -168,13 +169,14 @@ def resize(img, size, interpolation='BILINEAR'):
168169
raise TypeError('img should be numpy image. Got {}'.format(type(img)))
169170
if not (isinstance(size, int) or (isinstance(size, collections.Iterable) and len(size) == 2)):
170171
raise TypeError('Got inappropriate size arg: {}'.format(size))
171-
172+
172173
w, h, = size
173174
if isinstance(size, int):
174175
# h, w, c = img.shape #this would defeat the purpose of "size"
175-
176+
176177
if (w <= h and w == size) or (h <= w and h == size):
177178
return img
179+
img = srgb2linear(img)
178180
if w < h:
179181
ow = size
180182
oh = int(size * h / w)
@@ -184,9 +186,10 @@ def resize(img, size, interpolation='BILINEAR'):
184186
ow = int(size * w / h)
185187
output = cv2.resize(img, dsize=(ow, oh), interpolation=_cv2_str2interpolation[interpolation])
186188
else:
189+
img = srgb2linear(img)
187190
output = cv2.resize(img, dsize=(size[1], size[0]), interpolation=_cv2_str2interpolation[interpolation])
188-
189-
return output
191+
192+
return linear2srgb(output)
190193

191194

192195
def scale(*args, **kwargs):

codes/dataops/augmentations.py

Lines changed: 16 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,8 @@
55

66
import numpy as np
77
import dataops.common as util
8-
from dataops.common import fix_img_channels, get_image_paths, read_img, np2tensor
8+
from dataops.colors import linear2srgb, srgb2linear
9+
from dataops.common import fix_img_channels, np2tensor
910
from dataops.debug import *
1011
from dataops.imresize import resize as imresize # resize # imresize_np
1112

@@ -202,20 +203,24 @@ def __call__(self, img:np.ndarray) -> np.ndarray:
202203
if len(self.out_shape) < 3:
203204
self.out_shape = self.out_shape + (image_channels(img),)
204205

206+
img = srgb2linear(img)
207+
205208
if self.kind == 'transforms':
206209
if self.out_shape:
207-
return resize(
208-
np.copy(img),
210+
img = resize(
211+
img,
209212
w=self.out_shape[1], h=self.out_shape[0],
210213
method=self.interpolation)
211-
return scale_(
212-
np.copy(img), self.scale, method=self.interpolation)
213-
scale = None if self.out_shape else 1/self.scale
214-
# return imresize_np(
215-
# np.copy(img), scale=scale, antialiasing=self.antialiasing, interpolation=self.interpolation)
216-
return imresize(
217-
np.copy(img), scale, out_shape=self.out_shape,
218-
antialiasing=self.antialiasing, interpolation=self.interpolation)
214+
else:
215+
img = scale_(
216+
img, self.scale, method=self.interpolation)
217+
else:
218+
scale = None if self.out_shape else 1/self.scale
219+
img = imresize(
220+
img, scale, out_shape=self.out_shape,
221+
antialiasing=self.antialiasing, interpolation=self.interpolation)
222+
223+
return linear2srgb(img)
219224

220225

221226
def get_resize(size=None, scale=None, ds_algo=None,

codes/dataops/colors.py

Lines changed: 57 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -193,14 +193,65 @@ def yuv_to_rgb(input: torch.Tensor, consts='yuv') -> torch.Tensor:
193193
b: torch.Tensor = y + Wb * u_shifted
194194
return torch.stack((r, g, b), -3)
195195

196-
# Not tested:
197-
def rgb2srgb(imgs):
198-
return torch.where(imgs<=0.04045,imgs/12.92,torch.pow((imgs+0.055)/1.055,2.4))
199196

200-
# Not tested:
201-
def srgb2rgb(imgs):
202-
return torch.where(imgs<=0.0031308,imgs*12.92,1.055*torch.pow((imgs),1/2.4)-0.055)
197+
def srgb2linear(img):
198+
"""Convert sRGB images to linear RGB color space.
199+
Tensors are left as f32 in the range [0, 1].
200+
Uint8 numpy arrays are converted from uint8 in the range [0, 255]
201+
to f32 in the range [0, 1].
202+
F32 numpy arrays are assumed to already be linear RGB.
203+
Always returns a new array.
204+
All values are exact as per the sRGB spec.
205+
"""
206+
a = 0.055
207+
att = 12.92
208+
gamma = 2.4
209+
th = 0.04045
210+
211+
if isinstance(img, torch.Tensor):
212+
return torch.where(
213+
img <= th, img / att, torch.pow((img + a)/(1 + a), gamma))
214+
215+
if img.dtype == np.uint8:
216+
linear = np.float32(img) / 255.0
217+
218+
return np.where(
219+
linear <= th, linear / att, np.power((linear + a) / (1 + a), gamma))
220+
221+
return img.copy()
222+
223+
224+
def linear2srgb(img):
225+
Convert linear RGB to the sRGB color space.
226+
Tensors are left as f32 in the range [0, 1].
227+
F32 numpy arrays are converted back to the expected uint8 format
228+
in the range [0, 255].
229+
Uint8 numpy arrays are assumed to already be sRGB.
230+
Always returns a new array.
231+
All values are exact as per the sRGB spec.
232+
"""
233+
a = 0.055
234+
att = 12.92
235+
gamma = 2.4
236+
th = 0.0031308
237+
238+
if isinstance(img, torch.Tensor):
239+
return torch.where(
240+
img <= th,
241+
img * att, (1 + a) * torch.pow((img), 1 / gamma) - a)
242+
243+
if img.dtype == np.float32:
244+
srgb = np.clip(img, 0.0, 1.0)
245+
246+
srgb = np.where(
247+
srgb <= th, srgb * att, (1 + a) * np.power(srgb, 1.0 / gamma) - a)
248+
249+
np.clip(srgb * 255, 0.0, 255, out=srgb)
250+
np.around(srgb, out=srgb)
251+
252+
return srgb.astype(np.uint8)
203253

254+
return img.copy()
204255

205256

206257
def color_shift(image: torch.Tensor, mode:str='uniform',

0 commit comments

Comments
 (0)