@@ -23,6 +23,9 @@ class Omniglot(VisionDataset):
2323 download (bool, optional): If true, downloads the dataset zip files from the internet and
2424 puts it in root directory. If the zip files are already downloaded, they are not
2525 downloaded again.
26+ loader (callable, optional): A function to load an image given its path.
27+ By default, it uses PIL as its image loader, but users could also pass in
28+ ``torchvision.io.decode_image`` for decoding image data into tensors directly.
2629 """
2730
2831 folder = "omniglot-py"
@@ -39,6 +42,7 @@ def __init__(
3942 transform : Optional [Callable ] = None ,
4043 target_transform : Optional [Callable ] = None ,
4144 download : bool = False ,
45+ loader : Optional [Callable [[Union [str , Path ]], Any ]] = None ,
4246 ) -> None :
4347 super ().__init__ (join (root , self .folder ), transform = transform , target_transform = target_transform )
4448 self .background = background
@@ -59,6 +63,7 @@ def __init__(
5963 for idx , character in enumerate (self ._characters )
6064 ]
6165 self ._flat_character_images : List [Tuple [str , int ]] = sum (self ._character_images , [])
66+ self .loader = loader
6267
6368 def __len__ (self ) -> int :
6469 return len (self ._flat_character_images )
@@ -73,7 +78,7 @@ def __getitem__(self, index: int) -> Tuple[Any, Any]:
7378 """
7479 image_name , character_class = self ._flat_character_images [index ]
7580 image_path = join (self .target_folder , self ._characters [character_class ], image_name )
76- image = Image .open (image_path , mode = "r" ).convert ("L" )
81+ image = Image .open (image_path , mode = "r" ).convert ("L" ) if self . loader is None else self . loader ( image_path )
7782
7883 if self .transform :
7984 image = self .transform (image )
0 commit comments