Commit d73578a
Refactored the explode write-path

Moves per-chunk buffering and writing out of ColumnWriter/FieldBuffer and into the new PcvcfFieldWriter and PcvcfPartitionWriter classes, makes the zstd Blosc compressor a class-level default on PickleChunkedVcf, and adds PickleChunkedVcfField.iter_values(start, stop) for reading a slice of records.

1 parent caa3398 · commit d73578a

File tree

- bio2zarr/vcf.py
- tests/test_pcvcf.py

2 files changed: +101 −74 lines

bio2zarr/vcf.py

Lines changed: 93 additions & 74 deletions
@@ -238,7 +238,8 @@ def scan_vcf(path, target_num_partitions):
 
         regions = indexed_vcf.partition_into_regions(num_parts=target_num_partitions)
         logger.info(
-            f"Split {path} into {len(regions)} regions (target={target_num_partitions})")
+            f"Split {path} into {len(regions)} regions (target={target_num_partitions})"
+        )
         for region in regions:
             metadata.partitions.append(
                 VcfPartition(
@@ -521,51 +522,29 @@ def transform(self, vcf_value):
 
 
 class PickleChunkedVcfField:
-    def __init__(self, vcf_field, base_path):
+    def __init__(self, pcvcf, vcf_field):
         self.vcf_field = vcf_field
-        if vcf_field.category == "fixed":
-            self.path = base_path / vcf_field.name
-        else:
-            self.path = base_path / vcf_field.category / vcf_field.name
+        self.path = self.get_path(pcvcf.path, vcf_field)
+        self.compressor = pcvcf.compressor
+        self.num_partitions = pcvcf.num_partitions
+        self.num_records = pcvcf.num_records
 
-        # TODO Check if other compressors would give reasonable compression
-        # with significantly faster times
-        self.compressor = numcodecs.Blosc(cname="zstd", clevel=7)
-        # TODO have a clearer way of defining this state between
-        # read and write mode.
-        self.num_partitions = None
-        self.num_records = None
-        self.partition_num_chunks = {}
+    @staticmethod
+    def get_path(base_path, vcf_field):
+        if vcf_field.category == "fixed":
+            return base_path / vcf_field.name
+        return base_path / vcf_field.category / vcf_field.name
 
     def __repr__(self):
         return f"PickleChunkedVcfField(path={self.path})"
 
     def num_chunks(self, partition_index):
-        if partition_index not in self.partition_num_chunks:
-            partition_path = self.path / f"p{partition_index}"
-            n = len(list(partition_path.iterdir()))
-            self.partition_num_chunks[partition_index] = n
-        return self.partition_num_chunks[partition_index]
+        partition_path = self.path / f"p{partition_index}"
+        return len(list(partition_path.iterdir()))
 
     def chunk_path(self, partition_index, chunk_index):
         return self.path / f"p{partition_index}" / f"c{chunk_index}"
 
-    def write_chunk(self, partition_index, chunk_index, data):
-        path = self.chunk_path(partition_index, chunk_index)
-        logger.debug(f"Start write: {path}")
-        pkl = pickle.dumps(data)
-        # NOTE assuming that reusing the same compressor instance
-        # from multiple threads is OK!
-        compressed = self.compressor.encode(pkl)
-        with open(path, "wb") as f:
-            f.write(compressed)
-
-        # Update the summary
-        self.vcf_field.summary.num_chunks += 1
-        self.vcf_field.summary.compressed_size += len(compressed)
-        self.vcf_field.summary.uncompressed_size += len(pkl)
-        logger.debug(f"Finish write: {path}")
-
     def read_chunk(self, partition_index, chunk_index):
         path = self.chunk_path(partition_index, chunk_index)
         with open(path, "rb") as f:
@@ -587,6 +566,18 @@ def iter_values_bytes(self):
                 f"Corruption detected: incorrect number of records in {str(self.path)}."
             )
 
+    def iter_values(self, start=None, stop=None):
+        start = 0 if start is None else start
+        stop = self.num_records if stop is None else stop
+        num_records = 0
+        for partition_index in range(self.num_partitions):
+            for chunk_index in range(self.num_chunks(partition_index)):
+                chunk, chunk_bytes = self.read_chunk(partition_index, chunk_index)
+                for record in chunk:
+                    if start <= num_records < stop:
+                        yield record
+                    num_records += 1
+
     # Note: this involves some computation so should arguably be a method,
     # but making a property for consistency with xarray etc
     @property
@@ -627,91 +618,116 @@ def sanitiser_factory(self, shape):
 
 
 @dataclasses.dataclass
-class FieldBuffer:
-    field: PickleChunkedVcfField
+class PcvcfFieldWriter:
+    vcf_field: VcfField
+    path: pathlib.Path
     transformer: VcfValueTransformer
+    compressor: Any
+    max_buffered_bytes: int
     buff: list = dataclasses.field(default_factory=list)
     buffered_bytes: int = 0
     chunk_index: int = 0
 
     def append(self, val):
+        val = self.transformer.transform_and_update_bounds(val)
+        assert val is None or isinstance(val, np.ndarray)
         self.buff.append(val)
         val_bytes = sys.getsizeof(val)
         self.buffered_bytes += val_bytes
+        if self.buffered_bytes >= self.max_buffered_bytes:
+            logger.debug(
+                f"Flush {self.path} buffered={self.buffered_bytes} max={self.max_buffered_bytes}"
+            )
+            self.write_chunk()
+            self.buff.clear()
+            self.buffered_bytes = 0
+            self.chunk_index += 1
 
-    def reset(self):
-        self.buff = []
-        self.buffered_bytes = 0
-        self.chunk_index += 1
+    def write_chunk(self):
+        path = self.path / f"c{self.chunk_index}"
+        logger.debug(f"Start write: {path}")
+        pkl = pickle.dumps(self.buff)
+        compressed = self.compressor.encode(pkl)
+        with open(path, "wb") as f:
+            f.write(compressed)
 
+        # Update the summary
+        self.vcf_field.summary.num_chunks += 1
+        self.vcf_field.summary.compressed_size += len(compressed)
+        self.vcf_field.summary.uncompressed_size += len(pkl)
+        logger.debug(f"Finish write: {path}")
+
+    def flush(self):
+        logger.debug(
+            f"Flush {self.path} records={len(self.buff)} buffered={self.buffered_bytes}"
+        )
+        if len(self.buff) > 0:
+            self.write_chunk()
+
+
+class PcvcfPartitionWriter(contextlib.AbstractContextManager):
+    """
+    Writes the data for a PickleChunkedVcf for a given partition.
+    """
 
-class ColumnWriter(contextlib.AbstractContextManager):
     def __init__(
         self,
         vcf_metadata,
         out_path,
         partition_index,
+        compressor,
         *,
         chunk_size=1,
     ):
         self.partition_index = partition_index
         # chunk_size is in megabytes
-        self.max_buffered_bytes = chunk_size * 2**20
-        assert self.max_buffered_bytes > 0
+        max_buffered_bytes = chunk_size * 2**20
+        assert max_buffered_bytes > 0
 
-        self.buffers = {}
+        self.field_writers = {}
         num_samples = len(vcf_metadata.samples)
         for vcf_field in vcf_metadata.fields:
-            field = PickleChunkedVcfField(vcf_field, out_path)
+            field_path = PickleChunkedVcfField.get_path(out_path, vcf_field)
+            field_partition_path = field_path / f"p{partition_index}"
             transformer = VcfValueTransformer.factory(vcf_field, num_samples)
-            self.buffers[vcf_field.full_name] = FieldBuffer(field, transformer)
+            self.field_writers[vcf_field.full_name] = PcvcfFieldWriter(
+                vcf_field,
+                field_partition_path,
+                transformer,
+                compressor,
+                max_buffered_bytes,
+            )
 
     @property
     def field_summaries(self):
         return {
-            name: buff.field.vcf_field.summary for name, buff in self.buffers.items()
+            name: field.vcf_field.summary for name, field in self.field_writers.items()
         }
 
     def append(self, name, value):
-        buff = self.buffers[name]
-        # print("Append", name, value)
-        value = buff.transformer.transform_and_update_bounds(value)
-        assert value is None or isinstance(value, np.ndarray)
-        buff.append(value)
-        val_bytes = sys.getsizeof(value)
-        buff.buffered_bytes += val_bytes
-        if buff.buffered_bytes >= self.max_buffered_bytes:
-            self._flush_buffer(name, buff)
-
-    def _flush_buffer(self, name, buff):
-        logger.debug(f"Schedule write {name}:{self.partition_index}.{buff.chunk_index}")
-        buff.field.write_chunk(
-            self.partition_index,
-            buff.chunk_index,
-            buff.buff,
-        )
-        buff.reset()
+        self.field_writers[name].append(value)
 
     def __exit__(self, exc_type, exc_val, exc_tb):
         if exc_type is None:
-            for name, buff in self.buffers.items():
-                self._flush_buffer(name, buff)
+            for field in self.field_writers.values():
+                field.flush()
         return False
 
 
 class PickleChunkedVcf(collections.abc.Mapping):
+    # TODO Check if other compressors would give reasonable compression
+    # with significantly faster times
+    DEFAULT_COMPRESSOR = numcodecs.Blosc(cname="zstd", clevel=7)
+
     def __init__(self, path, metadata, vcf_header):
         self.path = path
         self.metadata = metadata
         self.vcf_header = vcf_header
+        self.compressor = self.DEFAULT_COMPRESSOR
 
         self.columns = {}
         for field in self.metadata.fields:
-            self.columns[field.full_name] = PickleChunkedVcfField(field, path)
-
-        for col in self.columns.values():
-            col.num_partitions = self.num_partitions
-            col.num_records = self.num_records
+            self.columns[field.full_name] = PickleChunkedVcfField(self, field)
 
     def __getitem__(self, key):
         return self.columns[key]
@@ -816,10 +832,13 @@ def convert_partition(
         else:
             format_fields.append(field)
 
-    with ColumnWriter(
+    compressor = PickleChunkedVcf.DEFAULT_COMPRESSOR
+
+    with PcvcfPartitionWriter(
         vcf_metadata,
         out_path,
         partition_index,
+        compressor,
         chunk_size=column_chunk_size,
     ) as tcw:
         with vcf_utils.IndexedVcf(partition.vcf_path) as ivcf:
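
For orientation, a minimal sketch of how the refactored write path composes after this commit, based only on the code in this diff. The names vcf_metadata, out_path, partition_index, column_chunk_size and the variants loop are illustrative stand-ins for values convert_partition prepares elsewhere, not the real parsing code.

# Sketch of the refactored write path; everything except the classes and
# attributes shown in this diff is a hypothetical stand-in.
compressor = PickleChunkedVcf.DEFAULT_COMPRESSOR

with PcvcfPartitionWriter(
    vcf_metadata,                  # parsed VCF metadata (.samples, .fields)
    out_path,                      # root directory of the exploded output
    partition_index,               # the partition this worker is writing
    compressor,
    chunk_size=column_chunk_size,  # per-field buffer size, in megabytes
) as tcw:
    for variant in variants:       # hypothetical per-record loop
        # append routes the value to that field's PcvcfFieldWriter, which
        # transforms it, buffers it, and writes a compressed pickle chunk
        # (p{partition_index}/c{chunk_index}) once max_buffered_bytes is hit.
        tcw.append("POS", variant.POS)
# On clean exit, __exit__ calls flush() on each field writer, so partially
# filled buffers are written out as a final chunk.

One design consequence worth noting: because each PcvcfFieldWriter now owns its partition directory and chunk counter, the partition writer no longer needs to track buffer state itself, and PickleChunkedVcfField can stay a pure read-side class.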

tests/test_pcvcf.py

Lines changed: 8 additions & 0 deletions
@@ -56,6 +56,14 @@ def test_POS(self, pcvcf):
             [111, 112, 14370, 17330, 1110696, 1230237, 1234567, 1235237, 10],
         )
 
+    def test_POS_slice(self, pcvcf):
+        col = pcvcf["POS"]
+        v = [row[0] for row in col.values]
+        start = 1
+        stop = 6
+        s = [row[0] for row in col.iter_values(start, stop)]
+        assert v[start:stop] == s
+
     def test_REF(self, pcvcf):
         ref = ["A", "A", "G", "T", "A", "T", "G", "T", "AC"]
         assert pcvcf["REF"].values == ref
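
The new test exercises the iter_values slice path end to end. As a usage sketch, assuming pcvcf is an exploded PickleChunkedVcf like the fixture these tests use:

col = pcvcf["POS"]
# iter_values walks the partitions and their chunks in order, yielding only
# records whose global index falls in [start, stop).
middle = list(col.iter_values(1, 6))
# With no arguments it defaults to start=0 and stop=num_records, so it
# yields every record in the column.
everything = list(col.iter_values())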

0 commit comments