Skip to content

Commit 470f89c

Browse files
authored
Fix hang in set_task when any worker has written 0 samples (#784)
* write empty index file if no sample provided * Fixup * remove test code
1 parent dc69b17 commit 470f89c

File tree

1 file changed

+28
-4
lines changed

1 file changed

+28
-4
lines changed

pyhealth/datasets/base_dataset.py

Lines changed: 28 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -115,6 +115,30 @@ def _csv_tsv_gz_path(path: str) -> str:
115115

116116
raise FileNotFoundError(f"Neither path exists: {path} or {alt_path}")
117117

118+
119+
def _litdata_merge(cache_dir: Path) -> None:
    """
    Merges LitData binary writer index files in the given cache directory.

    The number of workers passed to the merge is the number of index files
    actually found in the directory, not the number of workers that were
    launched.

    Args:
        cache_dir (Path): The cache directory containing LitData binary writer files.

    Raises:
        ValueError: If the directory contains no file whose name ends with
            the LitData index filename.
    """
    from litdata.streaming.writer import _INDEX_FILENAME

    entries = os.listdir(cache_dir)

    # A merged index is already present, so there is nothing left to do.
    if _INDEX_FILENAME in entries:
        return

    # Collect every file whose name ends with the index filename.
    worker_indexes = [name for name in entries if name.endswith(_INDEX_FILENAME)]

    # No index file at all: fail loudly instead of attempting a merge.
    if not worker_indexes:
        raise ValueError("There are zero samples in the dataset, please check the task and processors.")

    writer = BinaryWriter(cache_dir=str(cache_dir), chunk_bytes="64MB")
    writer.merge(num_workers=len(worker_indexes))
140+
141+
118142
class _ProgressContext:
119143
def __init__(self, queue: multiprocessing.queues.Queue | None, total: int, **kwargs):
120144
"""
@@ -695,7 +719,7 @@ def _task_transform(self, task: BaseTask, output_dir: Path, num_workers: int) ->
695719
if num_workers == 1:
696720
logger.info("Single worker mode, processing sequentially")
697721
_task_transform_fn((0, task, patient_ids, global_event_df, output_dir))
698-
BinaryWriter(cache_dir=str(output_dir), chunk_bytes="64MB").merge(num_workers)
722+
_litdata_merge(output_dir)
699723
return
700724

701725
# spawn is required for polars in multiprocessing, see https://docs.pola.rs/user-guide/misc/multiprocessing/#summary
@@ -721,7 +745,7 @@ def _task_transform(self, task: BaseTask, output_dir: Path, num_workers: int) ->
721745
while not queue.empty():
722746
progress.update(queue.get())
723747
result.get() # ensure exceptions are raised
724-
BinaryWriter(cache_dir=str(output_dir), chunk_bytes="64MB").merge(num_workers)
748+
_litdata_merge(output_dir)
725749

726750
logger.info(f"Task transformation completed and saved to {output_dir}")
727751
except Exception as e:
@@ -745,7 +769,7 @@ def _proc_transform(self, task_df: Path, output_dir: Path, num_workers: int) ->
745769
if num_workers == 1:
746770
logger.info("Single worker mode, processing sequentially")
747771
_proc_transform_fn((0, task_df, 0, num_samples, output_dir))
748-
BinaryWriter(cache_dir=str(output_dir), chunk_bytes="64MB").merge(num_workers)
772+
_litdata_merge(output_dir)
749773
return
750774

751775
ctx = multiprocessing.get_context("spawn")
@@ -771,7 +795,7 @@ def _proc_transform(self, task_df: Path, output_dir: Path, num_workers: int) ->
771795
while not queue.empty():
772796
progress.update(queue.get())
773797
result.get() # ensure exceptions are raised
774-
BinaryWriter(cache_dir=str(output_dir), chunk_bytes="64MB").merge(num_workers)
798+
_litdata_merge(output_dir)
775799

776800
logger.info(f"Processor transformation completed and saved to {output_dir}")
777801
except Exception as e:

0 commit comments

Comments
 (0)