@@ -528,6 +528,10 @@ def __init__(self, pcvcf, vcf_field):
         self.compressor = pcvcf.compressor
         self.num_partitions = pcvcf.num_partitions
         self.num_records = pcvcf.num_records
+        self.partition_record_index = pcvcf.partition_record_index
+        # A map of partition index to the cumulative number of records
+        # in chunks
+        self._chunk_cumulative_records = {}
 
     @staticmethod
     def get_path(base_path, vcf_field):
@@ -536,17 +540,29 @@ def get_path(base_path, vcf_field):
         return base_path / vcf_field.category / vcf_field.name
 
     def __repr__(self):
-        return f"PickleChunkedVcfField(path={self.path})"
+        partition_chunks = [self.num_chunks(j) for j in range(self.num_partitions)]
+        return f"PickleChunkedVcfField(partition_chunks={partition_chunks}, path={self.path})"
 
     def num_chunks(self, partition_index):
+        return len(self.chunk_files(partition_index))
+
+    def chunk_cumulative_records(self, partition_index):
+        if partition_index not in self._chunk_cumulative_records:
+            partition_path = self.path / f"p{partition_index}"
+            # Let numpy do the string->int parsing
+            a = np.array(os.listdir(partition_path), dtype=int)
+            a.sort()
+            self._chunk_cumulative_records[partition_index] = a
+        return self._chunk_cumulative_records[partition_index]
+
+    def chunk_files(self, partition_index):
         partition_path = self.path / f"p{partition_index}"
-        return len(list(partition_path.iterdir()))
-
-    def chunk_path(self, partition_index, chunk_index):
-        return self.path / f"p{partition_index}" / f"c{chunk_index}"
+        return [
+            partition_path / str(n)
+            for n in self.chunk_cumulative_records(partition_index)
+        ]
 
-    def read_chunk(self, partition_index, chunk_index):
-        path = self.chunk_path(partition_index, chunk_index)
+    def read_chunk(self, path):
         with open(path, "rb") as f:
             pkl = self.compressor.decode(f.read())
         return pickle.loads(pkl), len(pkl)
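Reviewer note: chunk files in each partition directory are now named by cumulative record count rather than `c{chunk_index}`, which is why `chunk_cumulative_records` can recover the counts from `os.listdir` alone. A minimal sketch (the directory listing is made up) of why the parse-to-int-then-sort step matters:

import numpy as np

# Hypothetical listing of a partition directory: each chunk file is
# named with the cumulative number of records written so far.
names = ["1000", "2000", "10000", "10345"]
# Lexicographic order would put "2000" after "10000"; parsing to int
# first (as chunk_cumulative_records does) sorts them numerically.
a = np.array(names, dtype=int)
a.sort()
print(a)  # [ 1000  2000 10000 10345]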
@@ -555,8 +571,8 @@ def iter_values_bytes(self):
         num_records = 0
         bytes_read = 0
         for partition_index in range(self.num_partitions):
-            for chunk_index in range(self.num_chunks(partition_index)):
-                chunk, chunk_bytes = self.read_chunk(partition_index, chunk_index)
+            for chunk_path in self.chunk_files(partition_index):
+                chunk, chunk_bytes = self.read_chunk(chunk_path)
                 bytes_read += chunk_bytes
                 for record in chunk:
                     yield record, bytes_read
@@ -569,13 +585,21 @@ def iter_values_bytes(self):
     def iter_values(self, start=None, stop=None):
         start = 0 if start is None else start
         stop = self.num_records if stop is None else stop
-        num_records = 0
-        for partition_index in range(self.num_partitions):
-            for chunk_index in range(self.num_chunks(partition_index)):
-                chunk, chunk_bytes = self.read_chunk(partition_index, chunk_index)
+        start_partition = (
+            np.searchsorted(self.partition_record_index, start, side="right") - 1
+        )
+        num_records = self.partition_record_index[start_partition]
+        assert num_records <= start
+        for partition_index in range(start_partition, self.num_partitions):
+            # TODO use the offsets from the partition chunk counts to seek to
+            # the first chunk
+            for chunk_path in self.chunk_files(partition_index):
+                chunk, _ = self.read_chunk(chunk_path)
                 for record in chunk:
                     if start <= num_records < stop:
                         yield record
+                    if num_records >= stop:
+                        return
                     num_records += 1
 
     # Note: this involves some computation so should arguably be a method,
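The `start_partition` computation above is the standard cumulative-index trick; a small worked example with made-up partition sizes shows how `np.searchsorted` with `side="right"` maps a record index to the partition that contains it:

import numpy as np

# Three hypothetical partitions holding 100, 250 and 50 records.
partition_record_index = np.cumsum([0, 100, 250, 50])
# -> array([  0, 100, 350, 400])

# Records 0..99 live in partition 0, 100..349 in partition 1, etc.
for start in [0, 99, 100, 349, 350]:
    p = np.searchsorted(partition_record_index, start, side="right") - 1
    print(start, "->", p)
# 0 -> 0, 99 -> 0, 100 -> 1, 349 -> 1, 350 -> 2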
@@ -627,13 +651,15 @@ class PcvcfFieldWriter:
     buff: list = dataclasses.field(default_factory=list)
     buffered_bytes: int = 0
     chunk_index: int = 0
+    num_records: int = 0
 
     def append(self, val):
         val = self.transformer.transform_and_update_bounds(val)
         assert val is None or isinstance(val, np.ndarray)
         self.buff.append(val)
         val_bytes = sys.getsizeof(val)
         self.buffered_bytes += val_bytes
+        self.num_records += 1
         if self.buffered_bytes >= self.max_buffered_bytes:
             logger.debug(
                 f"Flush {self.path} buffered={self.buffered_bytes} max={self.max_buffered_bytes}"
@@ -644,7 +670,7 @@ def append(self, val):
         self.chunk_index += 1
 
     def write_chunk(self):
-        path = self.path / f"c{self.chunk_index}"
+        path = self.path / f"{self.num_records}"
         logger.debug(f"Start write: {path}")
         pkl = pickle.dumps(self.buff)
         compressed = self.compressor.encode(pkl)
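Writer and reader sides of the naming scheme meet here: `write_chunk` names each file with the cumulative `num_records` at flush time, so the reader's sorted filename array doubles as a per-partition offset table. For instance, per-chunk record counts fall out as successive differences (the values below are made up):

import numpy as np

# Sorted cumulative filenames for one hypothetical partition.
cumulative = np.array([1000, 2000, 2345])
# Records per chunk are the successive differences.
records_per_chunk = np.diff(cumulative, prepend=0)
print(records_per_chunk)  # [1000 1000  345]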
@@ -667,7 +693,7 @@ def flush(self):
 
 
 class PcvcfPartitionWriter(contextlib.AbstractContextManager):
     """
-    Writes the data for a PickleChunkedVcf for a given partition.
+    Writes the data for a PickleChunkedVcf partition.
     """
 
     def __init__(
@@ -724,11 +750,21 @@ def __init__(self, path, metadata, vcf_header):
         self.metadata = metadata
         self.vcf_header = vcf_header
         self.compressor = self.DEFAULT_COMPRESSOR
-
         self.columns = {}
+        partition_num_records = [
+            partition.num_records for partition in self.metadata.partitions
+        ]
+        # Allow us to find which partition a given record is in
+        self.partition_record_index = np.cumsum([0] + partition_num_records)
         for field in self.metadata.fields:
             self.columns[field.full_name] = PickleChunkedVcfField(self, field)
 
+    def __repr__(self):
+        return (
+            f"PickleChunkedVcf(fields={len(self)}, partitions={self.num_partitions}, "
+            f"records={self.num_records}, path={self.path})"
+        )
+
     def __getitem__(self, key):
         return self.columns[key]
 
@@ -931,7 +967,6 @@ def convert(
             json.dump(vcf_metadata.asdict(), f, indent=4)
         with open(out_path / "header.txt", "w") as f:
             f.write(header)
-        return pcvcf
 
 
 def explode(
@@ -946,13 +981,14 @@ def explode(
    if out_path.exists():
        shutil.rmtree(out_path)
 
-    return PickleChunkedVcf.convert(
+    PickleChunkedVcf.convert(
        vcfs,
        out_path,
        column_chunk_size=column_chunk_size,
        worker_processes=worker_processes,
        show_progress=show_progress,
    )
+    return PickleChunkedVcf.load(out_path)
 
 
 def inspect(if_path):
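With this change `explode` no longer returns `convert`'s value; it reloads the result from disk, which also exercises `PickleChunkedVcf.load` and the new `__repr__`. A hedged usage sketch (the file and directory names are hypothetical):

# Explode a VCF, then reopen the on-disk result.
pcvcf = explode(["sample.vcf.gz"], "tmp/exploded")
print(pcvcf)
# e.g. PickleChunkedVcf(fields=..., partitions=..., records=..., path=tmp/exploded)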