Commit c948585

[bugfix] [training] fix deadlock in latent datasets and init error in multi-node training (hao-ai-lab#598)
1 parent e939e36 commit c948585
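
Both dataset paths previously let every rank touch the on-disk cache: a rank that found a usable cache could return before reaching `torch.distributed.barrier()`, while ranks that had to build the cache waited at the barrier, and the direct `torch.distributed.barrier()` calls that get removed here are a plausible source of the multi-node init error named in the title. The commit applies one pattern in both files: rank 0 alone loads or rebuilds the cache, every rank then synchronizes on the world group's `barrier()`, and only after that does every rank read the cache. Below is a minimal sketch of that pattern, assuming only the `get_world_rank()` / `get_world_group()` helpers that appear in the diff; the cache path and pickle payload are illustrative.

    # Minimal sketch of the synchronization pattern adopted by this commit.
    # get_world_rank() and get_world_group() are the FastVideo helpers used in
    # the diff; the cache contents here are illustrative only.
    import os
    import pickle

    from fastvideo.v1.distributed import get_world_group, get_world_rank


    def get_or_build_cache(cache_dir: str, cache_name: str = "cache.pkl"):
        cache_file = os.path.join(cache_dir, cache_name)

        if get_world_rank() == 0:
            # Rank 0 alone validates or (re)builds the cache, so ranks never
            # race on a half-written file.
            loaded = False
            if os.path.exists(cache_file):
                try:
                    with open(cache_file, "rb") as f:
                        pickle.load(f)
                    loaded = True
                except Exception:
                    loaded = False
            if not loaded:
                os.makedirs(cache_dir, exist_ok=True)
                with open(cache_file, "wb") as f:
                    pickle.dump({"payload": "expensive-to-compute"}, f)

        # Every rank reaches this barrier exactly once; there is no early
        # return or recursion before it, unlike the old code paths.
        get_world_group().barrier()

        # After the barrier the cache is guaranteed to exist; all ranks read it.
        with open(cache_file, "rb") as f:
            return pickle.load(f)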

File tree

4 files changed: +189 -132 lines changed

fastvideo/v1/dataset/parquet_dataset_iterable_style.py

Lines changed: 113 additions & 86 deletions
@@ -5,14 +5,13 @@
 import numpy as np
 import pyarrow as pa
 import pyarrow.parquet as pq
-import torch
 import tqdm
 from torch.utils.data import IterableDataset, get_worker_info
 from torchdata.stateful_dataloader import StatefulDataLoader
 
 from fastvideo.v1.dataset.utils import collate_latents_embs_masks
-from fastvideo.v1.distributed import (get_sp_world_size, get_world_rank,
-                                      get_world_size)
+from fastvideo.v1.distributed import (get_sp_world_size, get_world_group,
+                                      get_world_rank, get_world_size)
 from fastvideo.v1.logger import init_logger
 
 logger = init_logger(__name__)
@@ -157,95 +156,123 @@ def shard_parquet_files_across_sp_groups_and_workers(
     # Check if sharding plan already exists
     sharding_info_dir = os.path.join(
         path, f"sharding_info_{num_sp_groups}_sp_groups_{num_workers}_workers")
-    if os.path.exists(sharding_info_dir):
-        logger.info("Sharding plan already exists")
-        logger.info("Loading sharding plan from %s", sharding_info_dir)
-        try:
+
+    # Only rank 0 handles cache checking and file scanning
+    if get_world_rank() == 0:
+        cache_loaded = False
+        shard_parquet_files = None
+        shard_total_samples = None
+        shard_parquet_lengths = None
+
+        # First try to load existing sharding plan
+        if os.path.exists(sharding_info_dir):
+            logger.info("Loading sharding plan from %s", sharding_info_dir)
+            try:
+                with open(
+                        os.path.join(sharding_info_dir,
+                                     "shard_parquet_files.pkl"), "rb") as f:
+                    shard_parquet_files = pickle.load(f)
+                with open(
+                        os.path.join(sharding_info_dir,
+                                     "shard_total_samples.pkl"), "rb") as f:
+                    shard_total_samples = pickle.load(f)
+                with open(
+                        os.path.join(sharding_info_dir,
+                                     "shard_parquet_lengths.pkl"), "rb") as f:
+                    shard_parquet_lengths = pickle.load(f)
+                cache_loaded = True
+                logger.info("Successfully loaded sharding plan")
+            except Exception as e:
+                logger.error("Error loading sharding plan: %s", str(e))
+                logger.info("Falling back to creating new sharding plan")
+                cache_loaded = False
+
+        # If cache not loaded (either doesn't exist or failed to load), create sharding plan
+        if not cache_loaded:
+            logger.info("Creating new sharding plan")
+            logger.info("Scanning for parquet files in %s", path)
+
+            # Find all parquet files
+            parquet_files = []
+
+            for root, _, files in os.walk(path):
+                for file in files:
+                    if file.endswith('.parquet'):
+                        parquet_files.append(os.path.join(root, file))
+
+            if not parquet_files:
+                raise ValueError("No parquet files found in %s", path)
+
+            # Calculate file lengths efficiently using a single pass
+            logger.info("Calculating file lengths...")
+            lengths = []
+            for file in tqdm.tqdm(parquet_files, desc="Reading parquet files"):
+                lengths.append(pq.ParquetFile(file).metadata.num_rows)
+
+            total_samples = sum(lengths)
+            logger.info("Found %d files with %d total samples",
+                        len(parquet_files), total_samples)
+
+            # Sort files by length for better balancing
+            sorted_indices = np.argsort(lengths)
+            sorted_files = [parquet_files[i] for i in sorted_indices]
+            sorted_lengths = [lengths[i] for i in sorted_indices]
+
+            # Create shards
+            num_shards = num_sp_groups * num_workers
+            shard_parquet_files = [[] for _ in range(num_shards)]
+            shard_total_samples = [0] * num_shards
+            shard_parquet_lengths = [{} for _ in range(num_shards)]
+
+            # Distribute files to shards using a greedy approach
+            logger.info("Distributing files to shards...")
+            for file, length in zip(reversed(sorted_files),
+                                    reversed(sorted_lengths),
+                                    strict=True):
+                # Find shard with minimum current length
+                target_shard = np.argmin(shard_total_samples)
+                shard_parquet_files[target_shard].append(file)
+                shard_total_samples[target_shard] += length
+                shard_parquet_lengths[target_shard][file] = length
+            #randomize each shard
+            for shard in shard_parquet_files:
+                random.seed(seed)
+                random.shuffle(shard)
+
+            # Save the sharding plan
+            os.makedirs(sharding_info_dir, exist_ok=True)
             with open(
                     os.path.join(sharding_info_dir, "shard_parquet_files.pkl"),
-                    "rb") as f:
-                shard_parquet_files = pickle.load(f)
+                    "wb") as f:
+                pickle.dump(shard_parquet_files, f)
             with open(
                     os.path.join(sharding_info_dir, "shard_total_samples.pkl"),
-                    "rb") as f:
-                shard_total_samples = pickle.load(f)
+                    "wb") as f:
+                pickle.dump(shard_total_samples, f)
             with open(
                     os.path.join(sharding_info_dir,
-                                 "shard_parquet_lengths.pkl"), "rb") as f:
-                shard_parquet_lengths = pickle.load(f)
-            return shard_parquet_files, shard_total_samples, shard_parquet_lengths
-        except Exception as e:
-            logger.error("Error loading sharding plan: %s", str(e))
-            logger.info("Falling back to creating new sharding plan")
-
-    if get_world_rank() == 0:
-        logger.info("Scanning for parquet files in %s", path)
-
-        # Find all parquet files
-        parquet_files = []
-
-        for root, _, files in os.walk(path):
-            for file in files:
-                if file.endswith('.parquet'):
-                    parquet_files.append(os.path.join(root, file))
-
-        if not parquet_files:
-            raise ValueError("No parquet files found in %s", path)
-
-        # Calculate file lengths efficiently using a single pass
-        logger.info("Calculating file lengths...")
-        lengths = []
-        for file in tqdm.tqdm(parquet_files, desc="Reading parquet files"):
-            lengths.append(pq.ParquetFile(file).metadata.num_rows)
-
-        total_samples = sum(lengths)
-        logger.info("Found %d files with %d total samples", len(parquet_files),
-                    total_samples)
-
-        # Sort files by length for better balancing
-        sorted_indices = np.argsort(lengths)
-        sorted_files = [parquet_files[i] for i in sorted_indices]
-        sorted_lengths = [lengths[i] for i in sorted_indices]
-
-        # Create shards
-        num_shards = num_sp_groups * num_workers
-        shard_parquet_files = [[] for _ in range(num_shards)]
-        shard_total_samples = [0] * num_shards
-        shard_parquet_lengths = [{} for _ in range(num_shards)]
-
-        # Distribute files to shards using a greedy approach
-        logger.info("Distributing files to shards...")
-        for file, length in zip(reversed(sorted_files),
-                                reversed(sorted_lengths),
-                                strict=True):
-            # Find shard with minimum current length
-            target_shard = np.argmin(shard_total_samples)
-            shard_parquet_files[target_shard].append(file)
-            shard_total_samples[target_shard] += length
-            shard_parquet_lengths[target_shard][file] = length
-        #randomize each shard
-        for shard in shard_parquet_files:
-            random.seed(seed)
-            random.shuffle(shard)
-
-        save_dir = os.path.join(
-            path,
-            f"sharding_info_{num_sp_groups}_sp_groups_{num_workers}_workers")
-        os.makedirs(save_dir, exist_ok=True)
-        with open(os.path.join(save_dir, "shard_parquet_files.pkl"), "wb") as f:
-            pickle.dump(shard_parquet_files, f)
-        with open(os.path.join(save_dir, "shard_total_samples.pkl"), "wb") as f:
-            pickle.dump(shard_total_samples, f)
-        with open(os.path.join(save_dir, "shard_parquet_lengths.pkl"),
-                  "wb") as f:
-            pickle.dump(shard_parquet_lengths, f)
-        logger.info("Saved sharding info to %s", save_dir)
-
-    # wait for all ranks to finish
-    torch.distributed.barrier()
-    # recursive call
-    return shard_parquet_files_across_sp_groups_and_workers(
-        path, num_sp_groups, num_workers, seed)
+                                 "shard_parquet_lengths.pkl"), "wb") as f:
+                pickle.dump(shard_parquet_lengths, f)
+            logger.info("Saved sharding info to %s", sharding_info_dir)
+
+    # Wait for rank 0 to finish creating/loading sharding plan
+    world_group = get_world_group()
+    world_group.barrier()
+
+    # Now all ranks load the sharding plan (it should exist and be valid now)
+    logger.info("Loading sharding plan from %s after barrier",
+                sharding_info_dir)
+    with open(os.path.join(sharding_info_dir, "shard_parquet_files.pkl"),
+              "rb") as f:
+        shard_parquet_files = pickle.load(f)
+    with open(os.path.join(sharding_info_dir, "shard_total_samples.pkl"),
+              "rb") as f:
+        shard_total_samples = pickle.load(f)
+    with open(os.path.join(sharding_info_dir, "shard_parquet_lengths.pkl"),
+              "rb") as f:
+        shard_parquet_lengths = pickle.load(f)
+
+    return shard_parquet_files, shard_total_samples, shard_parquet_lengths
 
 
 def build_parquet_iterable_style_dataloader(
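
The rebuilt sharding plan above keeps the original greedy balancing: files are sorted by row count and handed out largest-first, each one going to whichever shard currently holds the fewest samples. A self-contained sketch of just that balancing step, with made-up file names, lengths, and shard count:

    # Standalone sketch of the greedy file balancing used in
    # shard_parquet_files_across_sp_groups_and_workers (illustrative inputs).
    import numpy as np

    files = ["a.parquet", "b.parquet", "c.parquet", "d.parquet", "e.parquet"]
    lengths = [1000, 250, 4000, 700, 1200]
    num_shards = 2  # num_sp_groups * num_workers in the real code

    # Sort ascending, then walk the order in reverse so the largest files are
    # placed first, which keeps per-shard totals close together.
    order = np.argsort(lengths)
    shard_files = [[] for _ in range(num_shards)]
    shard_totals = [0] * num_shards

    for idx in reversed(order):
        target = int(np.argmin(shard_totals))  # shard with the fewest samples
        shard_files[target].append(files[idx])
        shard_totals[target] += lengths[idx]

    print(shard_files)   # [['c.parquet'], ['e.parquet', 'a.parquet', 'd.parquet', 'b.parquet']]
    print(shard_totals)  # [4000, 3150]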

fastvideo/v1/dataset/parquet_dataset_map_style.py

Lines changed: 74 additions & 43 deletions
@@ -13,8 +13,8 @@
 from torchdata.stateful_dataloader import StatefulDataLoader
 
 from fastvideo.v1.dataset.utils import collate_rows_from_parquet_schema
-from fastvideo.v1.distributed import (get_sp_world_size, get_world_rank,
-                                      get_world_size)
+from fastvideo.v1.distributed import (get_sp_world_size, get_world_group,
+                                      get_world_rank, get_world_size)
 from fastvideo.v1.logger import init_logger
 
 logger = init_logger(__name__)
@@ -97,48 +97,64 @@ def get_parquet_files_and_length(path: str):
     cache_dir = os.path.join(path, "map_style_cache")
     cache_file = os.path.join(cache_dir, "file_info.pkl")
 
-    if os.path.exists(cache_file):
-        logger.info("Loading cached file info from %s", cache_file)
-        try:
-            with open(cache_file, "rb") as f:
-                file_names_sorted, lengths_sorted = pickle.load(f)
-            return file_names_sorted, lengths_sorted
-        except Exception as e:
-            logger.error("Error loading cached file info: %s", str(e))
-            logger.info("Falling back to scanning files")
-
-    # If no cache exists or loading failed, scan files
+    # Only rank 0 checks for cache and scans files if needed
     if get_world_rank() == 0:
-        lengths = []
-        file_names = []
-        for root, _, files in os.walk(path):
-            for file in sorted(files):
-                if file.endswith('.parquet'):
-                    file_path = os.path.join(root, file)
-                    file_names.append(file_path)
-        for file_path in tqdm.tqdm(file_names,
-                                   desc="Reading parquet files to get lengths"):
-            num_rows = pq.ParquetFile(file_path).metadata.num_rows
-            lengths.append(num_rows)
-        # sort according to file name to ensure all rank has the same order (in case os.walk is not sorted)
-        file_names_sorted, lengths_sorted = zip(*sorted(zip(file_names,
-                                                            lengths,
-                                                            strict=True),
-                                                        key=lambda x: x[0]),
-                                                strict=True)
-        assert len(
-            file_names_sorted) != 0, "No parquet files found in the dataset"
-
-        os.makedirs(cache_dir, exist_ok=True)
-        with open(cache_file, "wb") as f:
-            pickle.dump((file_names_sorted, lengths_sorted), f)
-        logger.info("Saved file info to %s", cache_file)
-
-    # Wait for rank 0 to finish saving
-    if get_world_size() > 1:
-        torch.distributed.barrier()
-
-    return get_parquet_files_and_length(path)
+        cache_loaded = False
+        file_names_sorted = None
+        lengths_sorted = None
+
+        # First try to load existing cache
+        if os.path.exists(cache_file):
+            logger.info("Loading cached file info from %s", cache_file)
+            try:
+                with open(cache_file, "rb") as f:
+                    file_names_sorted, lengths_sorted = pickle.load(f)
+                cache_loaded = True
+                logger.info("Successfully loaded cached file info")
+            except Exception as e:
+                logger.error("Error loading cached file info: %s", str(e))
+                logger.info("Falling back to scanning files")
+                cache_loaded = False
+
+        # If cache not loaded (either doesn't exist or failed to load), scan files
+        if not cache_loaded:
+            logger.info("Scanning parquet files to get lengths")
+            lengths = []
+            file_names = []
+            for root, _, files in os.walk(path):
+                for file in sorted(files):
+                    if file.endswith('.parquet'):
+                        file_path = os.path.join(root, file)
+                        file_names.append(file_path)
+            for file_path in tqdm.tqdm(
+                    file_names, desc="Reading parquet files to get lengths"):
+                num_rows = pq.ParquetFile(file_path).metadata.num_rows
+                lengths.append(num_rows)
+            # sort according to file name to ensure all rank has the same order
+            file_names_sorted, lengths_sorted = zip(*sorted(zip(file_names,
+                                                                lengths,
+                                                                strict=True),
+                                                            key=lambda x: x[0]),
+                                                    strict=True)
+            assert len(
+                file_names_sorted) != 0, "No parquet files found in the dataset"
+
+            # Save the cache
+            os.makedirs(cache_dir, exist_ok=True)
+            with open(cache_file, "wb") as f:
+                pickle.dump((file_names_sorted, lengths_sorted), f)
+            logger.info("Saved file info to %s", cache_file)
+
+    # Wait for rank 0 to finish creating/loading cache
+    world_group = get_world_group()
+    world_group.barrier()
+
+    # Now all ranks load the cache (it should exist and be valid now)
+    logger.info("Loading cached file info from %s after barrier", cache_file)
+    with open(cache_file, "rb") as f:
+        file_names_sorted, lengths_sorted = pickle.load(f)
+
+    return file_names_sorted, lengths_sorted
 
 
 def read_row_from_parquet_file(parquet_files: list[str], global_row_idx: int,
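
The cached index above is built from parquet footer metadata only, so counting rows never reads any column data. A standalone sketch of that scan, assuming an illustrative `/data/latents` directory:

    # Sketch: count rows per parquet file from footer metadata alone, the same
    # pq.ParquetFile(...).metadata.num_rows call used above. "/data/latents"
    # is a made-up path.
    import os

    import pyarrow.parquet as pq

    file_names = []
    for root, _, files in os.walk("/data/latents"):
        for name in sorted(files):
            if name.endswith(".parquet"):
                file_names.append(os.path.join(root, name))

    # Only the footer is read here, so this stays cheap even for large files.
    lengths = [pq.ParquetFile(path).metadata.num_rows for path in file_names]

    # Sort by file name so every rank sees an identical order.
    pairs = sorted(zip(file_names, lengths), key=lambda x: x[0])
    print(pairs[:3], sum(lengths))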
@@ -153,24 +169,39 @@ def read_row_from_parquet_file(parquet_files: list[str], global_row_idx: int,
     '''
     # find the parquet file and local row index
    cumulative = 0
+    file_index = 0
+    local_row_idx = 0
+
     for file_index in range(len(lengths)):
         if cumulative + lengths[file_index] > global_row_idx:
             local_row_idx = global_row_idx - cumulative
             break
         cumulative += lengths[file_index]
+    else:
+        # If we reach here, global_row_idx is out of bounds
+        raise IndexError(
+            f"global_row_idx {global_row_idx} is out of bounds for dataset")
 
     parquet_file = pq.ParquetFile(parquet_files[file_index])
 
     # Calculate the row group to read into memory and the local idx
     # This way we can avoid reading in the entire parquet file
     cumulative = 0
+    row_group_index = 0
+    local_index = 0
+
     for i in range(parquet_file.num_row_groups):
         num_rows = parquet_file.metadata.row_group(i).num_rows
         if cumulative + num_rows > local_row_idx:
             row_group_index = i
             local_index = local_row_idx - cumulative
             break
         cumulative += num_rows
+    else:
+        # If we reach here, local_row_idx is out of bounds for this parquet file
+        raise IndexError(
+            f"local_row_idx {local_row_idx} is out of bounds for parquet file {parquet_files[file_index]}"
+        )
 
     row_group = parquet_file.read_row_group(row_group_index).to_pydict()
     row_dict = {k: v[local_index] for k, v in row_group.items()}
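
The `for ... else` blocks added above make an out-of-range index fail loudly instead of silently reusing a stale loop variable: the `else` branch of a Python `for` loop runs only when the loop finishes without hitting `break`. A standalone sketch of the same file lookup against made-up per-file row counts:

    # Sketch of the for/else bounds check added in read_row_from_parquet_file,
    # applied to an illustrative list of per-file row counts.
    def locate_row(lengths: list[int], global_row_idx: int) -> tuple[int, int]:
        """Return (file_index, local_row_idx) for a global row index."""
        cumulative = 0
        for file_index in range(len(lengths)):
            if cumulative + lengths[file_index] > global_row_idx:
                local_row_idx = global_row_idx - cumulative
                break
            cumulative += lengths[file_index]
        else:
            # Runs only when the loop never hits `break`, i.e. the index is
            # past the last row of the last file.
            raise IndexError(
                f"global_row_idx {global_row_idx} is out of bounds for dataset")
        return file_index, local_row_idx


    print(locate_row([100, 50, 25], 120))  # -> (1, 20)
    try:
        locate_row([100, 50, 25], 500)
    except IndexError as e:
        print(e)  # -> global_row_idx 500 is out of bounds for dataset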
