1515from PIL import __version__ as PILLOW_VERSION , Image , ImageOps , ImageSequence
1616from torchvision .io .image import (
1717 _decode_avif ,
18+ _decode_heic ,
1819 decode_gif ,
1920 decode_image ,
2021 decode_jpeg ,
4142IS_WINDOWS = sys .platform in ("win32" , "cygwin" )
4243IS_MACOS = sys .platform == "darwin"
4344PILLOW_VERSION = tuple (int (x ) for x in PILLOW_VERSION .split ("." ))
# Directory with extra WebP sample images used by the PIL comparison test.
# Read from the environment; an empty string (the default) means "not set".
WEBP_TEST_IMAGES_DIR = os.environ.get("WEBP_TEST_IMAGES_DIR", "")

# Hacky way of figuring out whether we compiled with libavif/libheif (those are
# currently disabled by default): call each private decoder on a dummy uint8
# tensor and inspect the error. A flag is False only when the error message
# says torchvision was built without the corresponding library.
try:
    _decode_avif(torch.arange(10, dtype=torch.uint8))
except Exception as e:
    DECODE_AVIF_ENABLED = "torchvision not compiled with libavif support" not in str(e)
else:
    # No error at all: the decoder is evidently available. Binding the flag
    # here keeps it defined on every path, so later module-level checks such
    # as `if DECODE_AVIF_ENABLED:` cannot fail with a NameError.
    DECODE_AVIF_ENABLED = True

try:
    _decode_heic(torch.arange(10, dtype=torch.uint8))
except Exception as e:
    DECODE_HEIC_ENABLED = "torchvision not compiled with libheif support" not in str(e)
else:
    # Same reasoning as for AVIF: never leave the flag unbound.
    DECODE_HEIC_ENABLED = True
4458
4559
4660def _get_safe_image_name (name ):
@@ -148,17 +162,6 @@ def test_invalid_exif(tmpdir, size):
148162 torch .testing .assert_close (expected , output )
149163
150164
151- def test_decode_jpeg_errors ():
152- with pytest .raises (RuntimeError , match = "Expected a non empty 1-dimensional tensor" ):
153- decode_jpeg (torch .empty ((100 , 1 ), dtype = torch .uint8 ))
154-
155- with pytest .raises (RuntimeError , match = "Expected a torch.uint8 tensor" ):
156- decode_jpeg (torch .empty ((100 ,), dtype = torch .float16 ))
157-
158- with pytest .raises (RuntimeError , match = "Not a JPEG file" ):
159- decode_jpeg (torch .empty ((100 ), dtype = torch .uint8 ))
160-
161-
162165def test_decode_bad_huffman_images ():
163166 # sanity check: make sure we can decode the bad Huffman encoding
164167 bad_huff = read_file (os .path .join (DAMAGED_JPEG , "bad_huffman.jpg" ))
@@ -234,10 +237,6 @@ def test_decode_png(img_path, pil_mode, mode, scripted, decode_fun):
234237
235238
236239def test_decode_png_errors ():
237- with pytest .raises (RuntimeError , match = "Expected a non empty 1-dimensional tensor" ):
238- decode_png (torch .empty ((), dtype = torch .uint8 ))
239- with pytest .raises (RuntimeError , match = "Content is not png" ):
240- decode_png (torch .randint (3 , 5 , (300 ,), dtype = torch .uint8 ))
241240 with pytest .raises (RuntimeError , match = "Out of bound read in decode_png" ):
242241 decode_png (read_file (os .path .join (DAMAGED_PNG , "sigsegv.png" )))
243242 with pytest .raises (RuntimeError , match = "Content is too small for png" ):
@@ -863,20 +862,28 @@ def test_decode_gif(tmpdir, name, scripted):
863862 torch .testing .assert_close (tv_frame , pil_frame , atol = 0 , rtol = 0 )
864863
865864
# (decoder, regex expected in the error raised on random, non-decodable bytes)
decode_fun_and_match = [
    (decode_png, "Content is not png"),
    (decode_jpeg, "Not a JPEG file"),
    (decode_gif, re.escape("DGifOpenFileName() failed - 103")),
    (decode_webp, "WebPGetFeatures failed."),
]
# The AVIF / HEIC decoders are only exercised when torchvision was built with them.
if DECODE_AVIF_ENABLED:
    decode_fun_and_match += [(_decode_avif, "BMFF parsing failed")]
if DECODE_HEIC_ENABLED:
    decode_fun_and_match += [(_decode_heic, "Invalid input: No 'ftyp' box")]


@pytest.mark.parametrize("decode_fun, match", decode_fun_and_match)
def test_decode_bad_encoded_data(decode_fun, match):
    """Check that every decoder rejects malformed inputs with a RuntimeError.

    Each decoder is fed, in order: a 2-D tensor, a float tensor, a
    non-contiguous tensor, and finally a well-formed 1-D uint8 tensor of
    random bytes, which must trigger the decoder-specific ``match`` message.
    """
    payload = torch.randint(0, 256, (100,), dtype=torch.uint8)
    bad_inputs_and_messages = (
        (payload[None], "Input tensor must be 1-dimensional"),
        (payload.float(), "Input tensor must have uint8 data type"),
        (payload[::2], "Input tensor must be contiguous"),
        (payload, match),
    )
    for bad_input, expected_message in bad_inputs_and_messages:
        with pytest.raises(RuntimeError, match=expected_message):
            decode_fun(bad_input)
881888
882889
@@ -889,21 +896,27 @@ def test_decode_webp(decode_fun, scripted):
889896 img = decode_fun (encoded_bytes )
890897 assert img .shape == (3 , 100 , 100 )
891898 assert img [None ].is_contiguous (memory_format = torch .channels_last )
899+ img += 123 # make sure image buffer wasn't freed by underlying decoding lib
892900
893901
894- # This test is skipped because it requires webp images that we're not including
895- # within the repo. The test images were downloaded from the different pages of
896- # https://developers.google.com/speed/webp/gallery
897- # Note that converting an RGBA image to RGB leads to bad results because the
898- # transparent pixels aren't necessarily set to "black" or "white", they can be
899- # random stuff. This is consistent with PIL results.
900- @pytest .mark .skip (reason = "Need to download test images first" )
902+ # This test is skipped by default because it requires webp images that we're not
903+ # including within the repo. The test images were downloaded manually from the
904+ # different pages of https://developers.google.com/speed/webp/gallery
905+ @pytest .mark .skipif (not WEBP_TEST_IMAGES_DIR , reason = "WEBP_TEST_IMAGES_DIR is not set" )
901906@pytest .mark .parametrize ("decode_fun" , (decode_webp , decode_image ))
902907@pytest .mark .parametrize ("scripted" , (False , True ))
903908@pytest .mark .parametrize (
904- "mode, pil_mode" , ((ImageReadMode .RGB , "RGB" ), (ImageReadMode .RGB_ALPHA , "RGBA" ), (ImageReadMode .UNCHANGED , None ))
909+ "mode, pil_mode" ,
910+ (
911+ # Note that converting an RGBA image to RGB leads to bad results because the
912+ # transparent pixels aren't necessarily set to "black" or "white", they can be
913+ # random stuff. This is consistent with PIL results.
914+ (ImageReadMode .RGB , "RGB" ),
915+ (ImageReadMode .RGB_ALPHA , "RGBA" ),
916+ (ImageReadMode .UNCHANGED , None ),
917+ ),
905918)
906- @pytest .mark .parametrize ("filename" , Path ("/home/nicolashug/webp_samples" ).glob ("*.webp" ))
919+ @pytest .mark .parametrize ("filename" , Path (WEBP_TEST_IMAGES_DIR ).glob ("*.webp" ), ids = lambda p : p . name )
907920def test_decode_webp_against_pil (decode_fun , scripted , mode , pil_mode , filename ):
908921 encoded_bytes = read_file (filename )
909922 if scripted :
@@ -914,9 +927,10 @@ def test_decode_webp_against_pil(decode_fun, scripted, mode, pil_mode, filename)
914927 pil_img = Image .open (filename ).convert (pil_mode )
915928 from_pil = F .pil_to_tensor (pil_img )
916929 assert_equal (img , from_pil )
930+ img += 123 # make sure image buffer wasn't freed by underlying decoding lib
917931
918932
919- @pytest .mark .xfail ( reason = "AVIF support not enabled yet ." )
933+ @pytest .mark .skipif ( not DECODE_AVIF_ENABLED , reason = "AVIF support not enabled." )
920934@pytest .mark .parametrize ("decode_fun" , (_decode_avif , decode_image ))
921935@pytest .mark .parametrize ("scripted" , (False , True ))
922936def test_decode_avif (decode_fun , scripted ):
@@ -926,13 +940,20 @@ def test_decode_avif(decode_fun, scripted):
926940 img = decode_fun (encoded_bytes )
927941 assert img .shape == (3 , 100 , 100 )
928942 assert img [None ].is_contiguous (memory_format = torch .channels_last )
943+ img += 123 # make sure image buffer wasn't freed by underlying decoding lib
929944
930945
931- @pytest .mark .xfail (reason = "AVIF support not enabled yet." )
932946# Note: decode_image fails because some of these files have a (valid) signature
933947# we don't recognize. We should probably use libmagic....
934- # @pytest.mark.parametrize("decode_fun", (_decode_avif, decode_image))
935- @pytest .mark .parametrize ("decode_fun" , (_decode_avif ,))
# Private decoders to compare against PIL, restricted to the ones whose
# backing library was detected as compiled in (AVIF first, then HEIC).
decode_funs = [
    decoder
    for enabled, decoder in (
        (DECODE_AVIF_ENABLED, _decode_avif),
        (DECODE_HEIC_ENABLED, _decode_heic),
    )
    if enabled
]
953+
954+
955+ @pytest .mark .skipif (not decode_funs , reason = "Built without avif and heic support." )
956+ @pytest .mark .parametrize ("decode_fun" , decode_funs )
936957@pytest .mark .parametrize ("scripted" , (False , True ))
937958@pytest .mark .parametrize (
938959 "mode, pil_mode" ,
@@ -942,8 +963,10 @@ def test_decode_avif(decode_fun, scripted):
942963 (ImageReadMode .UNCHANGED , None ),
943964 ),
944965)
945- @pytest .mark .parametrize ("filename" , Path ("/home/nicolashug/dev/libavif/tests/data/" ).glob ("*.avif" ))
946- def test_decode_avif_against_pil (decode_fun , scripted , mode , pil_mode , filename ):
966+ @pytest .mark .parametrize (
967+ "filename" , Path ("/home/nicolashug/dev/libavif/tests/data/" ).glob ("*.avif" ), ids = lambda p : p .name
968+ )
969+ def test_decode_avif_heic_against_pil (decode_fun , scripted , mode , pil_mode , filename ):
947970 if "reversed_dimg_order" in str (filename ):
948971 # Pillow properly decodes this one, but we don't (order of parts of the
949972 # image is wrong). This is due to a bug that was recently fixed in
@@ -960,7 +983,14 @@ def test_decode_avif_against_pil(decode_fun, scripted, mode, pil_mode, filename)
960983 except RuntimeError as e :
961984 if any (
962985 s in str (e )
963- for s in ("BMFF parsing failed" , "avifDecoderParse failed: " , "file contains more than one image" )
986+ for s in (
987+ "BMFF parsing failed" ,
988+ "avifDecoderParse failed: " ,
989+ "file contains more than one image" ,
990+ "no 'ispe' property" ,
991+ "'iref' has double references" ,
992+ "Invalid image grid" ,
993+ )
964994 ):
965995 pytest .skip (reason = "Expected failure, that's OK" )
966996 else :
@@ -970,22 +1000,48 @@ def test_decode_avif_against_pil(decode_fun, scripted, mode, pil_mode, filename)
9701000 assert img .shape [0 ] == 3
9711001 if mode == ImageReadMode .RGB_ALPHA :
9721002 assert img .shape [0 ] == 4
1003+
9731004 if img .dtype == torch .uint16 :
9741005 img = F .to_dtype (img , dtype = torch .uint8 , scale = True )
1006+ try :
1007+ from_pil = F .pil_to_tensor (Image .open (filename ).convert (pil_mode ))
1008+ except RuntimeError as e :
1009+ if "Invalid image grid" in str (e ):
1010+ pytest .skip (reason = "PIL failure" )
1011+ else :
1012+ raise e
9751013
976- from_pil = F .pil_to_tensor (Image .open (filename ).convert (pil_mode ))
977- if False :
1014+ if True :
9781015 from torchvision .utils import make_grid
9791016
9801017 g = make_grid ([img , from_pil ])
9811018 F .to_pil_image (g ).save ((f"/home/nicolashug/out_images/{ filename .name } .{ pil_mode } .png" ))
982- if mode != ImageReadMode .RGB :
983- # We don't compare against PIL for RGB because results look pretty
984- # different on RGBA images (other images are fine). The result on
985- # torchvision basically just plainly ignores the alpha channel, resuting
986- # in transparent pixels looking dark. PIL seems to be using a sort of
987- # k-nn thing, looking at the output. Take a look at the resuting images.
988- torch .testing .assert_close (img , from_pil , rtol = 0 , atol = 3 )
1019+
1020+ is_decode_heic = getattr (decode_fun , "__name__" , getattr (decode_fun , "name" , None )) == "_decode_heic"
1021+ if mode == ImageReadMode .RGB and not is_decode_heic :
1022+ # We don't compare torchvision's AVIF against PIL for RGB because
1023+ # results look pretty different on RGBA images (other images are fine).
1024+ # The result on torchvision basically just plainly ignores the alpha
1025+ # channel, resulting in transparent pixels looking dark. PIL seems to be
1026+ # using a sort of k-nn thing (Take a look at the resulting images)
1027+ return
1028+ if filename .name == "sofa_grid1x5_420.avif" and is_decode_heic :
1029+ return
1030+
1031+ torch .testing .assert_close (img , from_pil , rtol = 0 , atol = 3 )
1032+
1033+
@pytest.mark.skipif(not DECODE_HEIC_ENABLED, reason="HEIC support not enabled yet.")
@pytest.mark.parametrize("decode_fun", (_decode_heic, decode_image))
@pytest.mark.parametrize("scripted", (False, True))
def test_decode_heic(decode_fun, scripted):
    """Decode the first ``.heic`` image found under FAKEDATA_DIR and check the result.

    Runs with both the HEIC-specific decoder and ``decode_image``, eagerly and
    through ``torch.jit.script``. The output must have shape (3, 100, 100) and
    be channels-last contiguous once a batch dimension is added.
    """
    heic_path = next(get_images(FAKEDATA_DIR, ".heic"))
    raw_bytes = read_file(heic_path)
    decoder = torch.jit.script(decode_fun) if scripted else decode_fun
    decoded = decoder(raw_bytes)

    assert decoded.shape == (3, 100, 100)
    assert decoded[None].is_contiguous(memory_format=torch.channels_last)
    decoded += 123  # make sure image buffer wasn't freed by underlying decoding lib
9891045
9901046
9911047if __name__ == "__main__" :
0 commit comments