Skip to content

Commit 227e7ca

Browse files
feat: create dataloader file
1 parent f2e173f commit 227e7ca

File tree

1 file changed

+38
-0
lines changed

1 file changed

+38
-0
lines changed

utils/dataloader.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
import os
2+
import torch
3+
from torchvision import datasets, transforms
4+
from torchvision.transforms import ToTensor, Resize, Compose
5+
from torch.utils.data import DataLoader, Dataset
6+
from PIL import Image
7+
8+
9+
def get_dataloader(data_path, batch_size):
10+
dataset = CustomDataset(data_path)
11+
12+
dataloader = DataLoader(
13+
dataset,
14+
batch_size=batch_size,
15+
shuffle=True
16+
)
17+
18+
return dataloader
19+
20+
21+
class CustomDataset(Dataset):
22+
def __init__(self, data_path):
23+
self.data_path = data_path
24+
self.image_files = os.listdir(data_path)
25+
26+
self.transforms = Compose([
27+
Resize((64, 64)),
28+
ToTensor()
29+
])
30+
31+
def __len__(self):
32+
return len(self.image_files)
33+
34+
def __getitem__(self, idx):
35+
image_path = os.path.join(self.data_path, self.image_files[idx])
36+
image = Image.open(image_path).convert('RGB')
37+
image = self.transforms(image)
38+
return image

0 commit comments

Comments
 (0)