11from pathlib import Path
2- from typing import Callable , Optional , Union
2+ from typing import Any , Callable , Optional , Union
33
4- from .folder import ImageFolder
4+ from .folder import default_loader , ImageFolder
55from .utils import download_and_extract_archive , verify_str_arg
66
77
@@ -21,6 +21,9 @@ class Country211(ImageFolder):
2121 target_transform (callable, optional): A function/transform that takes in the target and transforms it.
2222 download (bool, optional): If True, downloads the dataset from the internet and puts it into
2323 ``root/country211/``. If dataset is already downloaded, it is not downloaded again.
24+ loader (callable, optional): A function to load an image given its path.
25+ By default, it uses PIL as its image loader, but users could also pass in
26+ ``torchvision.io.decode_image`` for decoding image data into tensors directly.
2427 """
2528
2629 _URL = "https://openaipublic.azureedge.net/clip/data/country211.tgz"
@@ -33,6 +36,7 @@ def __init__(
3336 transform : Optional [Callable ] = None ,
3437 target_transform : Optional [Callable ] = None ,
3538 download : bool = False ,
39+ loader : Callable [[str ], Any ] = default_loader ,
3640 ) -> None :
3741 self ._split = verify_str_arg (split , "split" , ("train" , "valid" , "test" ))
3842
@@ -46,7 +50,12 @@ def __init__(
4650 if not self ._check_exists ():
4751 raise RuntimeError ("Dataset not found. You can use download=True to download it" )
4852
49- super ().__init__ (str (self ._base_folder / self ._split ), transform = transform , target_transform = target_transform )
53+ super ().__init__ (
54+ str (self ._base_folder / self ._split ),
55+ transform = transform ,
56+ target_transform = target_transform ,
57+ loader = loader ,
58+ )
5059 self .root = str (root )
5160
5261 def _check_exists (self ) -> bool :
0 commit comments