diff --git a/pyhealth/datasets/base_dataset.py b/pyhealth/datasets/base_dataset.py index 91a1c95a..4c81743f 100644 --- a/pyhealth/datasets/base_dataset.py +++ b/pyhealth/datasets/base_dataset.py @@ -115,6 +115,30 @@ def _csv_tsv_gz_path(path: str) -> str: raise FileNotFoundError(f"Neither path exists: {path} or {alt_path}") + +def _litdata_merge(cache_dir: Path) -> None: + """ + Merges LitData binary writer index files in the given cache directory. + + Args: + cache_dir (Path): The cache directory containing LitData binary writer files. + """ + from litdata.streaming.writer import _INDEX_FILENAME + files = os.listdir(cache_dir) + + # Return if the index already exists + if _INDEX_FILENAME in files: + return + + index_files = [f for f in files if f.endswith(_INDEX_FILENAME)] + + # Return if there are no index files to merge + if len(index_files) == 0: + raise ValueError("There are zero samples in the dataset, please check the task and processors.") + + BinaryWriter(cache_dir=str(cache_dir), chunk_bytes="64MB").merge(num_workers=len(index_files)) + + class _ProgressContext: def __init__(self, queue: multiprocessing.queues.Queue | None, total: int, **kwargs): """ @@ -695,7 +719,7 @@ def _task_transform(self, task: BaseTask, output_dir: Path, num_workers: int) -> if num_workers == 1: logger.info("Single worker mode, processing sequentially") _task_transform_fn((0, task, patient_ids, global_event_df, output_dir)) - BinaryWriter(cache_dir=str(output_dir), chunk_bytes="64MB").merge(num_workers) + _litdata_merge(output_dir) return # spwan is required for polars in multiprocessing, see https://docs.pola.rs/user-guide/misc/multiprocessing/#summary @@ -721,7 +745,7 @@ def _task_transform(self, task: BaseTask, output_dir: Path, num_workers: int) -> while not queue.empty(): progress.update(queue.get()) result.get() # ensure exceptions are raised - BinaryWriter(cache_dir=str(output_dir), chunk_bytes="64MB").merge(num_workers) + _litdata_merge(output_dir) logger.info(f"Task transformation completed and saved to {output_dir}") except Exception as e: @@ -745,7 +769,7 @@ def _proc_transform(self, task_df: Path, output_dir: Path, num_workers: int) -> if num_workers == 1: logger.info("Single worker mode, processing sequentially") _proc_transform_fn((0, task_df, 0, num_samples, output_dir)) - BinaryWriter(cache_dir=str(output_dir), chunk_bytes="64MB").merge(num_workers) + _litdata_merge(output_dir) return ctx = multiprocessing.get_context("spawn") @@ -771,7 +795,7 @@ def _proc_transform(self, task_df: Path, output_dir: Path, num_workers: int) -> while not queue.empty(): progress.update(queue.get()) result.get() # ensure exceptions are raised - BinaryWriter(cache_dir=str(output_dir), chunk_bytes="64MB").merge(num_workers) + _litdata_merge(output_dir) logger.info(f"Processor transformation completed and saved to {output_dir}") except Exception as e: @@ -902,4 +926,4 @@ def _main_guard(self, func_name: str): f"{func_name} method accessed from a non-main process. This may lead to unexpected behavior.\n" + "Consider use __name__ == '__main__' guard when using multiprocessing." ) - exit(1) \ No newline at end of file + exit(1)