pyhealth/datasets/base_dataset.py: 34 changes (29 additions & 5 deletions)
@@ -115,6 +115,30 @@ def _csv_tsv_gz_path(path: str) -> str:
 
     raise FileNotFoundError(f"Neither path exists: {path} or {alt_path}")
 
+
+def _litdata_merge(cache_dir: Path) -> None:
+    """
+    Merges LitData binary writer index files in the given cache directory.
+
+    Args:
+        cache_dir (Path): The cache directory containing LitData binary writer files.
+    """
+    from litdata.streaming.writer import _INDEX_FILENAME
+    files = os.listdir(cache_dir)
+
+    # Return early if the merged index already exists
+    if _INDEX_FILENAME in files:
+        return
+
+    index_files = [f for f in files if f.endswith(_INDEX_FILENAME)]
+
+    # Fail loudly if there are no per-worker index files to merge
+    if len(index_files) == 0:
+        raise ValueError("There are zero samples in the dataset, please check the task and processors.")
+
+    BinaryWriter(cache_dir=str(cache_dir), chunk_bytes="64MB").merge(num_workers=len(index_files))
+
+
 class _ProgressContext:
     def __init__(self, queue: multiprocessing.queues.Queue | None, total: int, **kwargs):
         """
@@ -695,7 +719,7 @@ def _task_transform(self, task: BaseTask, output_dir: Path, num_workers: int) ->
         if num_workers == 1:
             logger.info("Single worker mode, processing sequentially")
             _task_transform_fn((0, task, patient_ids, global_event_df, output_dir))
-            BinaryWriter(cache_dir=str(output_dir), chunk_bytes="64MB").merge(num_workers)
+            _litdata_merge(output_dir)
             return
 
         # spawn is required for polars in multiprocessing, see https://docs.pola.rs/user-guide/misc/multiprocessing/#summary
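
As the comment above notes, polars requires the spawn start method under multiprocessing. For reference, a minimal self-contained sketch of that pattern; the worker function and pool size are illustrative, not pyhealth's actual code:

    import multiprocessing

    def work_fn(job: int) -> int:
        return job * 2  # placeholder work

    if __name__ == "__main__":
        ctx = multiprocessing.get_context("spawn")  # fresh interpreter per worker
        with ctx.Pool(processes=4) as pool:
            print(pool.map(work_fn, range(8)))
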
@@ -721,7 +745,7 @@ def _task_transform(self, task: BaseTask, output_dir: Path, num_workers: int) ->
             while not queue.empty():
                 progress.update(queue.get())
             result.get()  # ensure exceptions are raised
-            BinaryWriter(cache_dir=str(output_dir), chunk_bytes="64MB").merge(num_workers)
+            _litdata_merge(output_dir)
 
             logger.info(f"Task transformation completed and saved to {output_dir}")
         except Exception as e:
@@ -745,7 +769,7 @@ def _proc_transform(self, task_df: Path, output_dir: Path, num_workers: int) ->
         if num_workers == 1:
             logger.info("Single worker mode, processing sequentially")
             _proc_transform_fn((0, task_df, 0, num_samples, output_dir))
-            BinaryWriter(cache_dir=str(output_dir), chunk_bytes="64MB").merge(num_workers)
+            _litdata_merge(output_dir)
             return
 
         ctx = multiprocessing.get_context("spawn")
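
The single-worker call above passes the full sample range (0, num_samples); under multiprocessing each worker presumably receives a slice of it. A sketch of one way to partition the range evenly across workers; this is illustrative only, and pyhealth's actual scheme may differ:

    def split_ranges(num_samples: int, num_workers: int):
        # Partition [0, num_samples) into contiguous, near-equal worker ranges.
        base, extra = divmod(num_samples, num_workers)
        start = 0
        for worker_id in range(num_workers):
            end = start + base + (1 if worker_id < extra else 0)
            yield worker_id, start, end
            start = end

    # e.g. list(split_ranges(10, 3)) == [(0, 0, 4), (1, 4, 7), (2, 7, 10)]
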
@@ -771,7 +795,7 @@ def _proc_transform(self, task_df: Path, output_dir: Path, num_workers: int) ->
             while not queue.empty():
                 progress.update(queue.get())
             result.get()  # ensure exceptions are raised
-            BinaryWriter(cache_dir=str(output_dir), chunk_bytes="64MB").merge(num_workers)
+            _litdata_merge(output_dir)
 
             logger.info(f"Processor transformation completed and saved to {output_dir}")
         except Exception as e:
@@ -902,4 +926,4 @@ def _main_guard(self, func_name: str):
             f"{func_name} method accessed from a non-main process. This may lead to unexpected behavior.\n"
             + "Consider using a __name__ == '__main__' guard when using multiprocessing."
         )
-        exit(1)
\ No newline at end of file
+        exit(1)
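
The _main_guard warning above asks callers to gate multiprocessing entry points behind the standard guard. A minimal sketch of the intended call pattern; the dataset class, its arguments, and the task name are illustrative placeholders, not a verified pyhealth API:

    from pyhealth.datasets import MIMIC4Dataset
    from pyhealth.tasks import MortalityPredictionMIMIC4  # hypothetical task name

    if __name__ == "__main__":
        # Without this guard, spawn-started workers would re-execute the module
        # top level and recursively kick off dataset processing.
        dataset = MIMIC4Dataset(root="/path/to/mimic4")  # hypothetical arguments
        samples = dataset.set_task(MortalityPredictionMIMIC4())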