-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathgcage_dataset.py
More file actions
107 lines (78 loc) · 3.17 KB
/
gcage_dataset.py
File metadata and controls
107 lines (78 loc) · 3.17 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
import csv
from pathlib import Path
from PIL import Image
import torch
from torch.utils.data import Dataset
import torchvision.transforms as T
class GCageTrainDataset(Dataset):
    """Per-sample training dataset backed by a CSV manifest.

    Each CSV row describes one sample: an RGB image, a single-channel heatmap
    and an integer label. Paths are used exactly as written in the CSV, so
    relative paths resolve against the process's current working directory,
    not the CSV's own directory.
    """

    # Columns every manifest row must provide.
    _REQUIRED_COLUMNS = ("image_path", "heatmap_path", "label")

    def __init__(self, data_path: str):
        """
        Args:
            data_path: path to the CSV manifest file (not a directory).
                CSV must have columns: image_path, heatmap_path, label.
                ``label`` must parse as an int.

        Raises:
            ValueError: if required columns are missing, a label is not an
                integer, or the CSV contains no data rows.
        """
        self.data_path = Path(data_path)
        self.samples = []  # list of row dicts, one per CSV row; "label" is an int
        with open(self.data_path, newline="", encoding="utf-8") as f:
            reader = csv.DictReader(f)
            # fieldnames is None for a completely empty file; that case falls
            # through to the "No rows found" error below.
            if reader.fieldnames is not None:
                missing = [c for c in self._REQUIRED_COLUMNS if c not in reader.fieldnames]
                if missing:
                    raise ValueError(
                        f"{self.data_path} is missing required column(s): {', '.join(missing)}"
                    )
            for row in reader:
                try:
                    row["label"] = int(row["label"])
                except (TypeError, ValueError) as err:
                    # TypeError covers short rows, where DictReader fills the
                    # missing field with None.
                    raise ValueError(
                        f"Invalid label {row['label']!r} on line {reader.line_num} "
                        f"of {self.data_path}"
                    ) from err
                self.samples.append(row)
        if not self.samples:
            raise ValueError(f"No rows found in {self.data_path}")
        # Built after parsing so manifest errors surface before transform setup.
        self.to_tensor = T.ToTensor()

    def __len__(self):
        """Return the number of CSV rows (samples)."""
        return len(self.samples)

    def __getitem__(self, idx):
        """Return ``(img, hm, label)`` for sample ``idx``.

        ``img`` is the ToTensor output of the RGB-converted image ([3, H, W]).
        ``hm`` is the ToTensor output of the grayscale heatmap ([1, H, W]).
        ``label`` is the int parsed from the CSV.
        """
        row = self.samples[idx]
        img = self._load_tensor(row["image_path"], "RGB")  # [3, H, W]
        hm = self._load_tensor(row["heatmap_path"], "L")  # [1, H, W]
        return img, hm, row["label"]

    def _load_tensor(self, path, mode):
        """Open the image at ``path``, convert it to ``mode``, and return it as a tensor.

        The file handle is closed deterministically via the context manager.
        """
        with Image.open(Path(path)) as im:
            return self.to_tensor(im.convert(mode))
class GCageEvalDataset(Dataset):
    """Evaluation dataset that groups CSV rows (crops) by their source image.

    Each item contains every row sharing one ``source_image_id``, stacked
    along a new leading axis, together with the group's label and id. Paths
    are used exactly as written in the CSV, so relative paths resolve against
    the process's current working directory.
    """

    # Columns every manifest row must provide.
    _REQUIRED_COLUMNS = ("image_path", "source_image_id", "heatmap_path", "label")

    def __init__(self, data_path: str):
        """
        Args:
            data_path: path to the CSV manifest file (not a directory).
                CSV must have columns: image_path, source_image_id,
                heatmap_path, label. ``label`` must parse as an int and must
                be identical for all rows that share a ``source_image_id``.

        Raises:
            ValueError: if required columns are missing, a label is not an
                integer, a row has an empty ``source_image_id``, rows in one
                group disagree on the label, or the CSV has no data rows.
        """
        self.data_path = Path(data_path)
        # source_image_id -> list of row dicts, in CSV order
        self.image_groups = {}
        with open(self.data_path, newline="", encoding="utf-8") as f:
            reader = csv.DictReader(f)
            # fieldnames is None for a completely empty file; that case falls
            # through to the "No grouped images" error below.
            if reader.fieldnames is not None:
                missing = [c for c in self._REQUIRED_COLUMNS if c not in reader.fieldnames]
                if missing:
                    raise ValueError(
                        f"{self.data_path} is missing required column(s): {', '.join(missing)}"
                    )
            for row in reader:
                try:
                    row["label"] = int(row["label"])
                except (TypeError, ValueError) as err:
                    # TypeError covers short rows, where DictReader fills the
                    # missing field with None.
                    raise ValueError(
                        f"Invalid label {row['label']!r} on line {reader.line_num} "
                        f"of {self.data_path}"
                    ) from err
                source_id = row.get("source_image_id")
                if not source_id:
                    raise ValueError(f"source_image_id missing for {row['image_path']}")
                group = self.image_groups.setdefault(source_id, [])
                # __getitem__ reports the first row's label for the whole group,
                # so reject groups whose rows disagree instead of mislabelling.
                if group and group[0]["label"] != row["label"]:
                    raise ValueError(
                        f"Inconsistent labels for source_image_id {source_id!r} in "
                        f"{self.data_path}: {group[0]['label']} vs {row['label']} "
                        f"(line {reader.line_num})"
                    )
                group.append(row)
        self.unique_image_ids = list(self.image_groups)
        if not self.unique_image_ids:
            raise ValueError(f"No grouped images found in {self.data_path}")
        # Built after parsing so manifest errors surface before transform setup.
        self.to_tensor = T.ToTensor()

    def __len__(self):
        """Return the number of distinct source images (groups)."""
        return len(self.unique_image_ids)

    def __getitem__(self, idx):
        """Return ``(imgs, hms, label, source_id)`` for group ``idx``.

        ``imgs`` is stacked to [K, 3, H, W] and ``hms`` to [K, 1, H, W], where K
        is the number of rows in the group. All rows in a group must share the
        same spatial size, otherwise ``torch.stack`` raises.
        """
        source_id = self.unique_image_ids[idx]
        rows = self.image_groups[source_id]
        imgs = []
        hms = []
        for row in rows:
            imgs.append(self._load_tensor(row["image_path"], "RGB"))
            hms.append(self._load_tensor(row["heatmap_path"], "L"))
        # Consistency across the group is enforced in __init__.
        label = rows[0]["label"]
        return torch.stack(imgs), torch.stack(hms), label, source_id

    def _load_tensor(self, path, mode):
        """Open the image at ``path``, convert it to ``mode``, and return it as a tensor.

        The file handle is closed deterministically via the context manager.
        """
        with Image.open(Path(path)) as im:
            return self.to_tensor(im.convert(mode))