|
| 1 | +import torch |
| 2 | +import torch.utils.data as data |
| 3 | +from .. import transforms |
| 4 | + |
| 5 | +class FakeData(data.Dataset): |
| 6 | + """A fake dataset that returns randomly generated images and returns them as PIL images |
| 7 | +
|
| 8 | + Args: |
| 9 | + size (int, optional): Size of the dataset. Default: 1000 images |
| 10 | + image_size(tuple, optional): Size if the returned images. Default: (3, 224, 224) |
| 11 | + num_classes(int, optional): Number of classes in the datset. Default: 10 |
| 12 | + transform (callable, optional): A function/transform that takes in an PIL image |
| 13 | + and returns a transformed version. E.g, ``transforms.RandomCrop`` |
| 14 | + target_transform (callable, optional): A function/transform that takes in the |
| 15 | + target and transforms it. |
| 16 | +
|
| 17 | + """ |
| 18 | + |
| 19 | + def __init__(self, size=1000, image_size=(3, 224, 224), num_classes=10, transform=None, target_transform=None): |
| 20 | + self.size = size |
| 21 | + self.num_classes = num_classes |
| 22 | + self.image_size = image_size |
| 23 | + self.transform = transform |
| 24 | + self.target_transform = target_transform |
| 25 | + |
| 26 | + def __getitem__(self, index): |
| 27 | + """ |
| 28 | + Args: |
| 29 | + index (int): Index |
| 30 | +
|
| 31 | + Returns: |
| 32 | + tuple: (image, target) where target is class_index of the target class. |
| 33 | + """ |
| 34 | + # create random image that is consistent with the index id |
| 35 | + rng_state = torch.get_rng_state() |
| 36 | + torch.manual_seed(index) |
| 37 | + img = torch.randn(*self.image_size) |
| 38 | + target = torch.Tensor(1).random_(0, self.num_classes)[0] |
| 39 | + torch.set_rng_state(rng_state) |
| 40 | + |
| 41 | + # convert to PIL Image |
| 42 | + img = transforms.ToPILImage()(img) |
| 43 | + if self.transform is not None: |
| 44 | + img = self.transform(img) |
| 45 | + if self.target_transform is not None: |
| 46 | + target = self.target_transform(target) |
| 47 | + |
| 48 | + return img, target |
| 49 | + |
| 50 | + def __len__(self): |
| 51 | + return self.size |
0 commit comments