
Commit d9c8384

Merge pull request #17 from ImYangYun/feat/minimal-loop-dtw
2 parents 1919d4d + bd2f7c8

File tree: 11 files changed (+1164 / -455 lines)


pyproject.toml

Lines changed: 4 additions & 0 deletions
```diff
@@ -14,6 +14,7 @@ dependencies = [
     "signwriting @ git+https://github.com/sign-language-processing/signwriting",
     "pose-anonymization @ git+https://github.com/sign-language-processing/pose-anonymization",
     "signwriting-evaluation @ git+https://github.com/sign-language-processing/signwriting-evaluation",
+    "pose-evaluation @ git+https://github.com/sign-language-processing/pose-evaluation",
     "transformers>=4.25",
     "CAMDM @ git+https://github.com/AmitMY/CAMDM",
 ]
@@ -30,10 +31,13 @@ column_limit = 120
 
 [tool.pylint]
 max-line-length = 120
+ignore-paths = ["signwriting_animation/translation"]
 disable = [
     "C0114", # Missing module docstring
     "C0115", # Missing class docstring
     "C0116", # Missing function or method docstring
+    "R0913", # Too many arguments
+    "R0917", # Too many positional arguments
 ]
 
 [tool.setuptools]
```
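The new pose-evaluation dependency, together with the branch name feat/minimal-loop-dtw, points at DTW-based comparison of pose sequences. For orientation, here is a minimal sketch of such a loop in plain NumPy (illustrative only; the actual pose-evaluation API does not appear in this diff):

```python
# Illustrative sketch only: a minimal O(T1*T2) DTW loop over two pose
# trajectories, in the spirit of the branch name (feat/minimal-loop-dtw).
# Assumes plain (T, J, C) arrays; this is NOT the pose-evaluation API.
import numpy as np

def minimal_dtw(a: np.ndarray, b: np.ndarray) -> float:
    """DTW distance between trajectories a [T1, J, C] and b [T2, J, C]."""
    t1, t2 = len(a), len(b)
    cost = np.full((t1 + 1, t2 + 1), np.inf)
    cost[0, 0] = 0.0
    for i in range(1, t1 + 1):
        for j in range(1, t2 + 1):
            # Frame-wise distance: mean Euclidean distance over joints
            d = np.linalg.norm(a[i - 1] - b[j - 1], axis=-1).mean()
            cost[i, j] = d + min(cost[i - 1, j], cost[i, j - 1], cost[i - 1, j - 1])
    return float(cost[t1, t2])
```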
Lines changed: 195 additions & 97 deletions
```diff
@@ -1,154 +1,233 @@
+# pylint: disable=too-many-instance-attributes,too-many-arguments,too-many-locals
 import os
+import math
 import random
-from dataclasses import dataclass
-from typing import Literal
+from typing import Literal, Optional
 import pandas as pd
 import torch
 from torch.utils.data import Dataset, DataLoader
 from pose_format.torch.masked.collator import zero_pad_collator
 from pose_format.pose import Pose
-from pose_anonymization.data.normalization import normalize_mean_std
+from pose_format.utils.generic import reduce_holistic
+from pose_anonymization.data.normalization import pre_process_pose
 from signwriting_evaluation.metrics.clip import signwriting_to_clip_image
 from transformers import CLIPProcessor
 
-@dataclass
-class DatasetConfig:
+
+def _coalesce_maybe_nan(x) -> Optional[int]:
     """
-    Configuration for dataset paths and frame sampling.
+    Convert NaN/None values to None, otherwise return the value.
+
+    Args:
+        x: Value to check (can be None, NaN, or numeric)
+
+    Returns:
+        None if input is None/NaN, otherwise the input value
     """
-    data_dir: str
-    csv_path: str
-    num_past_frames: int = 40
-    num_future_frames: int = 20
-    split: Literal['train', 'test', 'dev'] = 'train'
+    if x is None:
+        return None
+    if isinstance(x, float) and math.isnan(x):
+        return None
+    return x
+
```
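This helper exists because pandas loads empty start/end CSV cells as float NaN, and the old `rec.get("start") or None` pattern (removed further down in this diff) also maps a legitimate 0 to None. A few doctest-style checks of the behavior (values are hypothetical):

```python
# Doctest-style checks of _coalesce_maybe_nan (hypothetical values):
assert _coalesce_maybe_nan(None) is None
assert _coalesce_maybe_nan(float("nan")) is None
assert _coalesce_maybe_nan(0) == 0    # preserved, unlike `rec.get("start") or None`
assert _coalesce_maybe_nan(42) == 42
```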
```diff
 
 class DynamicPosePredictionDataset(Dataset):
     """
-    A PyTorch Dataset for dynamic sampling of normalized pose sequences,
-    conditioned on SignWriting images and optional scalar metadata.
-    Each sample includes past and future pose segments, associated masks,
-    and a CLIP-ready rendering of the SignWriting annotation.
+    PyTorch Dataset for dynamic sampling of pose sequences conditioned on SignWriting.
+
+    This dataset provides past and future pose windows for training diffusion models.
+    Data is returned in raw (unnormalized) format - normalization is handled by the
+    LightningModule to ensure consistency with precomputed statistics.
+
+    Data Pipeline:
+        Raw pose → reduce_holistic (586→178 keypoints) → pre_process_pose → return
+
+    Note: This preprocessing pipeline must match the one used to generate the
+    normalization statistics (mean_std_178_with_preprocess.pt).
+
+    Args:
+        data_dir: Root directory containing .pose files
+        csv_path: Path to CSV file with pose metadata and SignWriting text
+        num_past_frames: Number of past frames for conditioning (default: 60)
+        num_future_frames: Number of future frames to predict (default: 30)
+        with_metadata: Whether to include frame timing metadata (default: True)
+        clip_model_name: HuggingFace model name for CLIP processor
+        split: Data split to use ('train', 'dev', or 'test')
+        use_reduce_holistic: Whether to reduce keypoints to 178 (default: True)
     """
+
     def __init__(
         self,
-        config: DatasetConfig,
+        data_dir: str,
+        csv_path: str,
+        num_past_frames: int = 40,
+        num_future_frames: int = 20,
         with_metadata: bool = True,
         clip_model_name: str = "openai/clip-vit-base-patch32",
+        split: Literal["train", "dev", "test"] = "train",
+        use_reduce_holistic: bool = True,
     ):
         super().__init__()
-        assert config.split in ['train', 'test', 'dev']
-        self.data_dir = config.data_dir
-        self.num_past_frames = config.num_past_frames
-        self.num_future_frames = config.num_future_frames
+        assert split in ["train", "dev", "test"], f"Invalid split: {split}"
+
+        self.data_dir = data_dir
+        self.num_past_frames = num_past_frames
+        self.num_future_frames = num_future_frames
         self.with_metadata = with_metadata
-        df_records = pd.read_csv(config.csv_path)
-        df_records = df_records[df_records['split'] == config.split]
+        self.use_reduce_holistic = use_reduce_holistic
+
+        self.mean_std = None
+
+        df_records = pd.read_csv(csv_path)
+        df_records = df_records[df_records["split"] == split].reset_index(drop=True)
         self.records = df_records.to_dict(orient="records")
+
         self.clip_processor = CLIPProcessor.from_pretrained(clip_model_name)
 
-    def __len__(self):
+    def __len__(self) -> int:
         return len(self.records)
```
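The class docstring pins the preprocessing to reduce_holistic followed by pre_process_pose and warns it must match the pipeline behind mean_std_178_with_preprocess.pt. A quick way to sanity-check that on one file (the path is hypothetical; the 586→178 keypoint counts are taken from the docstring above):

```python
# Sketch: sanity-check the preprocessing pipeline on a single file.
# "example.pose" is a hypothetical path; expected keypoint counts come
# from the class docstring (586 before, 178 after reduce_holistic).
from pose_format.pose import Pose
from pose_format.utils.generic import reduce_holistic
from pose_anonymization.data.normalization import pre_process_pose

with open("example.pose", "rb") as f:
    pose = Pose.read(f)
print("keypoints before:", pose.body.data.shape[2])  # expect 586

pose = reduce_holistic(pose)
pose = pre_process_pose(pose)
print("keypoints after:", pose.body.data.shape[2])   # expect 178
```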
```diff
 
-    def _extract_pose_windows(self, pose):
+    def __getitem__(self, idx: int) -> dict:
         """
-        Extract past and future windows from the pose object.
-        Returns a dictionary with pose tensors and metadata.
+        Load and process a single training sample.
+
+        Returns a dictionary containing:
+            - data: Future pose sequence [T_future, J, C] (target for prediction)
+            - conditions:
+                - input_pose: Past pose sequence [T_past, J, C] (conditioning)
+                - input_mask: Validity mask for past poses [T_past]
+                - target_mask: Validity mask for future poses [T_future]
+                - sign_image: CLIP-processed SignWriting image [3, H, W]
+            - id: Sample identifier
+            - metadata: (optional) Frame timing information
+
+        If the requested pose file is too short or corrupted, recursively tries
+        the next sample to ensure training doesn't crash.
         """
+        rec = self.records[idx]
+
+        pose_path = os.path.join(self.data_dir, rec["pose"])
+        if not pose_path.endswith(".pose"):
+            pose_path += ".pose"
+
+        start = _coalesce_maybe_nan(rec.get("start"))
+        end = _coalesce_maybe_nan(rec.get("end"))
+
+        if not os.path.exists(pose_path):
+            raise FileNotFoundError(f"Pose file not found: {pose_path}")
+
+        # Load raw pose data
+        with open(pose_path, "rb") as f:
+            raw = Pose.read(f)
+
+        # Check if sequence is too short before preprocessing
+        total_frames = len(raw.body.data)
+        if total_frames < 5:
+            print(f"[SKIP SHORT FILE] idx={idx} | total_frames={total_frames} | "
+                  f"file={os.path.basename(pose_path)}")
+            return self.__getitem__((idx + 1) % len(self.records))
+
+        if self.use_reduce_holistic:
+            raw = reduce_holistic(raw)
+        raw = pre_process_pose(raw)
+        pose = raw  # Keep in raw scale (no normalization)
+
+        # Verify sequence is still valid after preprocessing
         total_frames = len(pose.body.data)
-        pivot_frame = random.randint(0, total_frames - 1)
+        if total_frames < 5:
+            print(f"[SKIP SHORT CLIP] idx={idx} | total_frames={total_frames}")
+            return self.__getitem__((idx + 1) % len(self.records))
 
-        input_start = max(0, pivot_frame - self.num_past_frames)
+        # Sample time windows intelligently
+        if total_frames <= (self.num_past_frames + self.num_future_frames + 2):
+            # Short sequence: use centered sampling to maximize data usage
+            pivot_frame = total_frames // 2
+            input_start = max(0, pivot_frame - self.num_past_frames // 2)
+            target_end = min(total_frames, input_start + self.num_past_frames + self.num_future_frames)
+        else:
+            # Long sequence: random sampling with proper boundaries
+            pivot_min = self.num_past_frames
+            pivot_max = total_frames - self.num_future_frames
+            pivot_frame = random.randint(pivot_min, pivot_max)
+            input_start = pivot_frame - self.num_past_frames
+            target_end = pivot_frame + self.num_future_frames
+
+        # Extract pose windows
         input_pose = pose.body[input_start:pivot_frame].torch()
-        target_end = min(total_frames, pivot_frame + self.num_future_frames)
         target_pose = pose.body[pivot_frame:target_end].torch()
 
-        return {
-            "input_data": input_pose.data.zero_filled(),
-            "target_data": target_pose.data.zero_filled(),
-            "input_mask": input_pose.data.mask,
-            "target_mask": target_pose.data.mask,
-            "target_length": torch.tensor([len(target_pose.data)], dtype=torch.float32),
-            "pivot_frame": pivot_frame,
-            "target_end": target_end,
-            "total_frames": total_frames,
-        }
+        # Debug logging for first few samples
+        if idx < 3:
+            print(f"[DEBUG SPLIT] idx={idx} | total={total_frames} | pivot={pivot_frame} | "
+                  f"input={input_start}:{pivot_frame} ({input_pose.data.shape[0]}f) | "
+                  f"target={pivot_frame}:{target_end} ({target_pose.data.shape[0]}f) | "
+                  f"file={os.path.basename(pose_path)}")
```
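The two sampling branches are easiest to follow with concrete numbers, using the values main() passes below (num_past_frames=60, num_future_frames=30, so the short/long threshold is 60 + 30 + 2 = 92):

```python
# Standalone restatement of the window arithmetic above:
past, future = 60, 30

# Long clip (total_frames = 200 > 92): random branch.
# pivot is drawn from [60, 170], so input (pivot-60:pivot) and
# target (pivot:pivot+30) are always full length.

# Short clip (total_frames = 80 <= 92): centered branch.
total = 80
pivot = total // 2                                    # 40
input_start = max(0, pivot - past // 2)               # 10
target_end = min(total, input_start + past + future)  # min(80, 100) = 80
# input = 10:40 (30 frames), target = 40:80 (40 frames). Note the target
# can exceed num_future_frames in this branch; zero_pad_collator evens out
# the variable lengths when batching.
```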
```diff
 
-    def _process_signwriting_image(self, text: str) -> torch.Tensor:
-        pil_img = signwriting_to_clip_image(text)
-        return self.clip_processor(images=pil_img, return_tensors="pt").pixel_values.squeeze(0)
+        # Extract data and masks
+        input_data = input_pose.data
+        target_data = target_pose.data
+        input_mask = input_pose.data.mask
+        target_mask = target_pose.data.mask
 
-    def _build_sample_dict(self, info: dict):
+        # Process SignWriting image through CLIP
+        pil_img = signwriting_to_clip_image(rec.get("text", ""))
+        sign_img = self.clip_processor(images=pil_img, return_tensors="pt").pixel_values.squeeze(0)
+
+        # Build output sample
         sample = {
-            "data": info["target_data"],
+            "data": target_data,  # Future window (prediction target, unnormalized)
             "conditions": {
-                "input_pose": info["input_data"],
-                "input_mask": info["input_mask"],
-                "target_mask": info["target_mask"],
-                "sign_image": info["sign_img"],
+                "input_pose": input_data,    # Past window (conditioning, unnormalized)
+                "input_mask": input_mask,    # Validity mask for past frames
+                "target_mask": target_mask,  # Validity mask for future frames
+                "sign_image": sign_img,      # CLIP-processed SignWriting [3, H, W]
             },
-            "id": info["rec"].get("id", os.path.basename(info["rec"]["pose"])),
-            "length_target": info["target_length"],
+            "id": rec.get("id", os.path.basename(rec["pose"])),
         }
 
+        # Add optional metadata for analysis
         if self.with_metadata:
             meta = {
-                "total_frames": info["total_frames"],
-                "sample_start": info["pivot_frame"],
-                "sample_end": info["target_end"],
-                "orig_start": info["rec"].get("start", 0),
-                "orig_end": info["rec"].get("end", info["total_frames"]),
+                "total_frames": total_frames,
+                "sample_start": pivot_frame,
+                "sample_end": pivot_frame + len(target_data),
+                "orig_start": start or 0,
+                "orig_end": end or total_frames,
             }
             sample["metadata"] = {
-                k: torch.tensor([v], dtype=torch.long)
+                k: torch.tensor([int(v)], dtype=torch.long)
                 for k, v in meta.items()
             }
 
         return sample
```
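The class docstring states that normalization happens in the LightningModule against mean_std_178_with_preprocess.pt; that module is not part of this diff. A hedged sketch of what the hand-off might look like (the {"mean", "std"} layout of the .pt file is an assumption, not confirmed by this commit):

```python
# Hedged sketch (LightningModule not shown in this diff): normalizing the
# raw samples returned above against the precomputed statistics file.
# The {"mean", "std"} layout of the .pt file is assumed.
import torch

stats = torch.load("mean_std_178_with_preprocess.pt")

def normalize(x: torch.Tensor) -> torch.Tensor:
    # x: raw pose data [B, T, J, C] as returned by the dataset
    return (x - stats["mean"]) / (stats["std"] + 1e-8)

def denormalize(x: torch.Tensor) -> torch.Tensor:
    return x * (stats["std"] + 1e-8) + stats["mean"]
```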

```diff
-    def __getitem__(self, idx):
-        rec = self.records[idx]
-        pose_path = os.path.join(self.data_dir, rec["pose"])
-
-        if not os.path.isfile(pose_path):
-            return self[random.randint(0, len(self.records) - 1)]
 
-        with open(pose_path, "rb") as f:
-            raw = Pose.read(
-                f,
-                start_time=rec.get("start") or None,
-                end_time=rec.get("end") or None
-            )
-
-        pose = normalize_mean_std(raw)
-        window = self._extract_pose_windows(pose)
-        sign_img = self._process_signwriting_image(rec.get("text", ""))
-
-        return self._build_sample_dict({
-            **window,
-            "sign_img": sign_img,
-            "rec": rec,
-        })
-
-def get_num_workers():
+def get_num_workers() -> int:
     """
-    Determine appropriate number of workers based on CPU availability.
+    Determine appropriate number of DataLoader workers based on CPU availability.
+
+    Returns:
+        0 if CPU count is unavailable or ≤1, otherwise the CPU count
     """
     cpu_count = os.cpu_count()
     return 0 if cpu_count is None or cpu_count <= 1 else cpu_count
 
+
```
```diff
 def main():
-    config = DatasetConfig(
-        data_dir="/scratch/yayun/pose_data/raw_poses",
-        csv_path="/scratch/yayun/pose_data/data.csv",
-        num_past_frames=40,
-        num_future_frames=20,
-        split='train'
-    )
+    """Test dataset loading and print sample batch statistics."""
+    data_dir = "/home/yayun/data/pose_data"
+    csv_path = "/home/yayun/data/signwriting-animation/data_fixed.csv"
 
     dataset = DynamicPosePredictionDataset(
-        config=config,
+        data_dir=data_dir,
+        csv_path=csv_path,
+        num_past_frames=60,
+        num_future_frames=30,
         with_metadata=True,
+        split="train",
+        use_reduce_holistic=True,
     )
+
     loader = DataLoader(
         dataset,
         batch_size=4,
@@ -158,15 +237,34 @@ def main():
         pin_memory=False,
     )
 
+    # Load and inspect a batch
     batch = next(iter(loader))
-    print("Batch:", batch["data"].shape)
-    print("Input pose:", batch["conditions"]["input_pose"].shape)
-    print("Input mask:", batch["conditions"]["input_mask"].shape)
-    print("Target mask:", batch["conditions"]["target_mask"].shape)
-    print("Sign image:", batch["conditions"]["sign_image"].shape)
+    print("Batch shapes:")
+    print(f"  Data (target): {batch['data'].shape}")
+    print(f"  Input pose: {batch['conditions']['input_pose'].shape}")
+    print(f"  Input mask: {batch['conditions']['input_mask'].shape}")
+    print(f"  Target mask: {batch['conditions']['target_mask'].shape}")
+    print(f"  Sign image: {batch['conditions']['sign_image'].shape}")
+
+    # Check data range (should be unnormalized)
+    data = batch["data"]
+    if hasattr(data, "tensor"):
+        data = data.tensor
+    print("\nData statistics (should be in raw range):")
+    print(f"  Min: {data.min().item():.4f}")
+    print(f"  Max: {data.max().item():.4f}")
+    print(f"  Mean: {data.mean().item():.4f}")
+    print(f"  Std: {data.std().item():.4f}")
+
+    if abs(data.mean().item()) < 0.1 and abs(data.std().item() - 1.0) < 0.2:
+        print("  Warning: Data appears normalized (should be raw)")
+    else:
+        print("  Data is in raw range (correct)")
+
     if "metadata" in batch:
+        print("\nMetadata:")
         for k, v in batch["metadata"].items():
-            print(f"Metadata {k}:", v.shape)
+            print(f"  {k}: {v.shape}")
 
-# if __name__ == "__main__":
-# main()
+#if __name__ == "__main__":
+#main()
```
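The second hunk skips old lines 155-157 of main(), so the middle of the DataLoader call (shuffle, num_workers, collate_fn) is not visible here. Given that the file imports zero_pad_collator and defines get_num_workers(), the elided wiring is presumably along these lines (an assumption, not shown in the diff):

```python
# Assumed (elided) middle of the DataLoader call; not visible in this diff:
loader = DataLoader(
    dataset,
    batch_size=4,
    shuffle=True,                  # assumed
    num_workers=get_num_workers(),
    collate_fn=zero_pad_collator,  # pads variable-length pose windows
    pin_memory=False,
)
```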
