# pylint: disable=too-many-instance-attributes,too-many-arguments,too-many-locals
import os
import math
import random
from typing import Literal, Optional

import pandas as pd
import torch
from torch.utils.data import Dataset, DataLoader
from pose_format.torch.masked.collator import zero_pad_collator
from pose_format.pose import Pose
from pose_format.utils.generic import reduce_holistic
from pose_anonymization.data.normalization import pre_process_pose
from signwriting_evaluation.metrics.clip import signwriting_to_clip_image
from transformers import CLIPProcessor


def _coalesce_maybe_nan(x) -> Optional[int]:
    """
    Convert NaN/None values to None, otherwise return the value.

    Args:
        x: Value to check (can be None, NaN, or numeric)

    Returns:
        None if input is None/NaN, otherwise the input value
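
    Examples:
        >>> _coalesce_maybe_nan(None) is None
        True
        >>> _coalesce_maybe_nan(float("nan")) is None
        True
        >>> _coalesce_maybe_nan(7)
        7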
    """
    if x is None:
        return None
    if isinstance(x, float) and math.isnan(x):
        return None
    return x


class DynamicPosePredictionDataset(Dataset):
    """
    PyTorch Dataset for dynamic sampling of pose sequences conditioned on SignWriting.

    This dataset provides past and future pose windows for training diffusion models.
    Data is returned in raw (unnormalized) form; normalization is handled by the
    LightningModule to ensure consistency with precomputed statistics.

    Data pipeline:
        Raw pose → reduce_holistic (586→178 keypoints) → pre_process_pose → return

    Note: This preprocessing pipeline must match the one used to generate the
    normalization statistics (mean_std_178_with_preprocess.pt).

    Args:
        data_dir: Root directory containing .pose files
        csv_path: Path to CSV file with pose metadata and SignWriting text
        num_past_frames: Number of past frames for conditioning (default: 40)
        num_future_frames: Number of future frames to predict (default: 20)
        with_metadata: Whether to include frame timing metadata (default: True)
        clip_model_name: HuggingFace model name for CLIP processor
        split: Data split to use ('train', 'dev', or 'test')
        use_reduce_holistic: Whether to reduce keypoints to 178 (default: True)
    """
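
    # Samples can vary in length (short clips use centered sampling), so batches
    # are intended to be built with zero_pad_collator, as in main() below.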

    def __init__(
        self,
        data_dir: str,
        csv_path: str,
        num_past_frames: int = 40,
        num_future_frames: int = 20,
        with_metadata: bool = True,
        clip_model_name: str = "openai/clip-vit-base-patch32",
        split: Literal["train", "dev", "test"] = "train",
        use_reduce_holistic: bool = True,
    ):
        super().__init__()
        assert split in ["train", "dev", "test"], f"Invalid split: {split}"

        self.data_dir = data_dir
        self.num_past_frames = num_past_frames
        self.num_future_frames = num_future_frames
        self.with_metadata = with_metadata
        self.use_reduce_holistic = use_reduce_holistic

        self.mean_std = None

        df_records = pd.read_csv(csv_path)
        df_records = df_records[df_records["split"] == split].reset_index(drop=True)
        self.records = df_records.to_dict(orient="records")

        self.clip_processor = CLIPProcessor.from_pretrained(clip_model_name)

    def __len__(self) -> int:
        return len(self.records)

    def __getitem__(self, idx: int) -> dict:
        """
        Load and process a single training sample.

        Returns a dictionary containing:
            - data: Future pose sequence [T_future, J, C] (target for prediction)
            - conditions:
                - input_pose: Past pose sequence [T_past, J, C] (conditioning)
                - input_mask: Validity mask for past poses [T_past]
                - target_mask: Validity mask for future poses [T_future]
                - sign_image: CLIP-processed SignWriting image [3, H, W]
            - id: Sample identifier
            - metadata: (optional) Frame timing information

        If the requested pose sequence is too short or corrupted, this method
        recursively tries the next sample so that training doesn't crash.
        """
        rec = self.records[idx]

        pose_path = os.path.join(self.data_dir, rec["pose"])
        if not pose_path.endswith(".pose"):
            pose_path += ".pose"

        start = _coalesce_maybe_nan(rec.get("start"))
        end = _coalesce_maybe_nan(rec.get("end"))

        if not os.path.exists(pose_path):
            raise FileNotFoundError(f"Pose file not found: {pose_path}")

        # Load raw pose data
        with open(pose_path, "rb") as f:
            raw = Pose.read(f)

        # Check if the sequence is too short before preprocessing
        total_frames = len(raw.body.data)
        if total_frames < 5:
            print(f"[SKIP SHORT FILE] idx={idx} | total_frames={total_frames} | "
                  f"file={os.path.basename(pose_path)}")
            return self.__getitem__((idx + 1) % len(self.records))

        if self.use_reduce_holistic:
            raw = reduce_holistic(raw)
        raw = pre_process_pose(raw)
        pose = raw  # Keep in raw scale (no normalization)

        # Verify the sequence is still valid after preprocessing
        total_frames = len(pose.body.data)
        if total_frames < 5:
            print(f"[SKIP SHORT CLIP] idx={idx} | total_frames={total_frames}")
            return self.__getitem__((idx + 1) % len(self.records))

        # Sample time windows
        if total_frames <= (self.num_past_frames + self.num_future_frames + 2):
            # Short sequence: use centered sampling to maximize data usage
            pivot_frame = total_frames // 2
            input_start = max(0, pivot_frame - self.num_past_frames // 2)
            target_end = min(total_frames, input_start + self.num_past_frames + self.num_future_frames)
        else:
            # Long sequence: random sampling with proper boundaries
            pivot_min = self.num_past_frames
            pivot_max = total_frames - self.num_future_frames
            pivot_frame = random.randint(pivot_min, pivot_max)
            input_start = pivot_frame - self.num_past_frames
            target_end = pivot_frame + self.num_future_frames
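
        # Worked example with the defaults (past=40, future=20): a 100-frame clip
        # takes the long branch, so pivot is drawn from [40, 80], input covers
        # [pivot-40, pivot) and target covers [pivot, pivot+20). A 50-frame clip
        # takes the short branch: pivot=25, input_start=max(0, 25-20)=5,
        # target_end=min(50, 5+60)=50, so the target may exceed num_future_frames.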

        # Extract pose windows
        input_pose = pose.body[input_start:pivot_frame].torch()
        target_pose = pose.body[pivot_frame:target_end].torch()

        # Debug logging for the first few samples
        if idx < 3:
            print(f"[DEBUG SPLIT] idx={idx} | total={total_frames} | pivot={pivot_frame} | "
                  f"input={input_start}:{pivot_frame} ({input_pose.data.shape[0]}f) | "
                  f"target={pivot_frame}:{target_end} ({target_pose.data.shape[0]}f) | "
                  f"file={os.path.basename(pose_path)}")

        # Extract data and masks
        input_data = input_pose.data
        target_data = target_pose.data
        input_mask = input_pose.data.mask
        target_mask = target_pose.data.mask

        # Process the SignWriting image through CLIP
        pil_img = signwriting_to_clip_image(rec.get("text", ""))
        sign_img = self.clip_processor(images=pil_img, return_tensors="pt").pixel_values.squeeze(0)
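        # For openai/clip-vit-base-patch32 the processor resizes to 224x224, so
        # pixel_values has shape [1, 3, 224, 224]; squeeze(0) drops the batch dim.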

        # Build the output sample
        sample = {
            "data": target_data,  # Future window (prediction target, unnormalized)
            "conditions": {
                "input_pose": input_data,  # Past window (conditioning, unnormalized)
                "input_mask": input_mask,  # Validity mask for past frames
                "target_mask": target_mask,  # Validity mask for future frames
                "sign_image": sign_img,  # CLIP-processed SignWriting [3, H, W]
            },
            "id": rec.get("id", os.path.basename(rec["pose"])),
        }

        # Add optional metadata for analysis
        if self.with_metadata:
            meta = {
                "total_frames": total_frames,
                "sample_start": pivot_frame,
                "sample_end": pivot_frame + len(target_data),
                "orig_start": start or 0,
                "orig_end": end or total_frames,
            }
            sample["metadata"] = {
                k: torch.tensor([int(v)], dtype=torch.long)
                for k, v in meta.items()
            }

        return sample


def get_num_workers() -> int:
    """
    Determine an appropriate number of DataLoader workers based on CPU availability.

    Returns:
        0 if the CPU count is unavailable or ≤ 1, otherwise the CPU count
    """
    cpu_count = os.cpu_count()
    return 0 if cpu_count is None or cpu_count <= 1 else cpu_count


def main():
    """Test dataset loading and print sample batch statistics."""
    data_dir = "/home/yayun/data/pose_data"
    csv_path = "/home/yayun/data/signwriting-animation/data_fixed.csv"

    dataset = DynamicPosePredictionDataset(
        data_dir=data_dir,
        csv_path=csv_path,
        num_past_frames=60,
        num_future_frames=30,
        with_metadata=True,
        split="train",
        use_reduce_holistic=True,
    )

    loader = DataLoader(
        dataset,
        batch_size=4,
        shuffle=True,
        num_workers=get_num_workers(),
        collate_fn=zero_pad_collator,  # pad variable-length samples per batch
        pin_memory=False,
    )

    # Load and inspect a batch
    batch = next(iter(loader))
    print("Batch shapes:")
    print(f"  Data (target): {batch['data'].shape}")
    print(f"  Input pose: {batch['conditions']['input_pose'].shape}")
    print(f"  Input mask: {batch['conditions']['input_mask'].shape}")
    print(f"  Target mask: {batch['conditions']['target_mask'].shape}")
    print(f"  Sign image: {batch['conditions']['sign_image'].shape}")

    # Check the data range (should be unnormalized)
    data = batch["data"]
    if hasattr(data, "tensor"):
        data = data.tensor
    print("\nData statistics (should be in raw range):")
    print(f"  Min: {data.min().item():.4f}")
    print(f"  Max: {data.max().item():.4f}")
    print(f"  Mean: {data.mean().item():.4f}")
    print(f"  Std: {data.std().item():.4f}")
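
    # Heuristic: standardized data has mean ≈ 0 and std ≈ 1, so values in that
    # band suggest normalization was applied somewhere upstream.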
    if abs(data.mean().item()) < 0.1 and abs(data.std().item() - 1.0) < 0.2:
        print("  Warning: Data appears normalized (should be raw)")
    else:
        print("  Data is in raw range (correct)")

    if "metadata" in batch:
        print("\nMetadata:")
        for k, v in batch["metadata"].items():
            print(f"  {k}: {v.shape}")


if __name__ == "__main__":
    main()