Skip to content

Commit 65a2ec0

Browse files
authored
Create an ElasticImageFolder for PyTorch. (#2486)
* Develop an image folder dataset for PyTorch * Add docstring
1 parent 7bdfebf commit 65a2ec0

File tree

1 file changed

+92
-0
lines changed

1 file changed

+92
-0
lines changed

elasticai_api/pytorch/dataset.py

Lines changed: 92 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,92 @@
1+
# Copyright 2021 The ElasticDL Authors. All rights reserved.
2+
# Licensed under the Apache License, Version 2.0 (the "License");
3+
# you may not use this file except in compliance with the License.
4+
# You may obtain a copy of the License at
5+
#
6+
# http://www.apache.org/licenses/LICENSE-2.0
7+
#
8+
# Unless required by applicable law or agreed to in writing, software
9+
# distributed under the License is distributed on an "AS IS" BASIS,
10+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11+
# See the License for the specific language governing permissions and
12+
# limitations under the License.
13+
14+
import sys
15+
from typing import Any, Tuple
16+
17+
import torchvision
18+
from torchvision.datasets.folder import default_loader
19+
20+
IMG_EXTENSIONS = (
21+
".jpg",
22+
".jpeg",
23+
".png",
24+
".ppm",
25+
".bmp",
26+
".pgm",
27+
".tif",
28+
".tiff",
29+
".webp",
30+
)
31+
32+
33+
class ElasticImageFolder(torchvision.datasets.ImageFolder):
34+
def __init__(
35+
self,
36+
root,
37+
transform=None,
38+
target_transform=None,
39+
loader=default_loader,
40+
is_valid_file=None,
41+
):
42+
"""Create a dataset from a folder for ElasticDL
43+
Arguments:
44+
root: the path of the image folder
45+
transform (callable, optional): A function/transform that takes in
46+
a sample and returns a transformed version.
47+
E.g, ``transforms.RandomCrop`` for images.
48+
target_transform (callable, optional): A function/transform that
49+
takes in the target and transforms it.
50+
loader (callable): A function to load a sample given its path.
51+
is_valid_file (callable, optional): A function that takes path of
52+
a file and check if the file is a valid file (used to check of
53+
corrupt files) both extensions and is_valid_file should not
54+
be passed.
55+
"""
56+
super(ElasticImageFolder, self).__init__(
57+
root,
58+
transform=transform,
59+
target_transform=target_transform,
60+
loader=loader,
61+
is_valid_file=is_valid_file,
62+
)
63+
self._data_shard_service = None
64+
65+
def set_data_shard_service(self, data_shard_service):
66+
self._data_shard_service = data_shard_service
67+
68+
def __len__(self):
69+
if self._data_shard_service:
70+
# Set the maxsize because the size of dataset is not fixed
71+
# when using dynamic sharding
72+
return sys.maxsize
73+
else:
74+
return len(self.samples)
75+
76+
def __getitem__(self, index: int) -> Tuple[Any, Any]:
77+
"""
78+
Args:
79+
index (int): Index
80+
Returns:
81+
tuple: (sample, target) where target is
82+
class_index of the target class.
83+
"""
84+
if self._data_shard_service:
85+
index = self._data_shard_service.fetch_record_index()
86+
path, target = self.samples[index]
87+
sample = self.loader(path)
88+
if self.transform is not None:
89+
sample = self.transform(sample)
90+
if self.target_transform is not None:
91+
target = self.target_transform(target)
92+
return sample, target

0 commit comments

Comments
 (0)