@@ -11,7 +11,7 @@
 import math
 import tempfile
 import contextlib
-from typing import Any
+from typing import Any, List
 
 import humanfriendly
 import cyvcf2
@@ -529,38 +529,49 @@ def __init__(self, pcvcf, vcf_field):
         self.num_partitions = pcvcf.num_partitions
         self.num_records = pcvcf.num_records
         self.partition_record_index = pcvcf.partition_record_index
-        # A map of partition index to the cumulative number of records
-        # in chunks
-        self._chunk_cumulative_records = {}
+        # A map of partition id to the cumulative number of records
+        # in chunks within that partition
+        self._chunk_record_index = {}
 
     @staticmethod
     def get_path(base_path, vcf_field):
         if vcf_field.category == "fixed":
             return base_path / vcf_field.name
         return base_path / vcf_field.category / vcf_field.name
 
+    def partition_path(self, partition_id):
+        return self.path / f"p{partition_id}"
+
     def __repr__(self):
         partition_chunks = [self.num_chunks(j) for j in range(self.num_partitions)]
-        return f"PickleChunkedVcfField(partition_chunks={partition_chunks}, path={self.path})"
-
-    def num_chunks(self, partition_index):
-        return len(self.chunk_files(partition_index))
-
-    def chunk_cumulative_records(self, partition_index):
-        if partition_index not in self._chunk_cumulative_records:
-            partition_path = self.path / f"p{partition_index}"
-            # Let numpy do the string->int parsing
-            a = np.array(os.listdir(partition_path), dtype=int)
-            a.sort()
-            self._chunk_cumulative_records[partition_index] = a
-        return self._chunk_cumulative_records[partition_index]
-
-    def chunk_files(self, partition_index):
-        partition_path = self.path / f"p{partition_index}"
-        return [
-            partition_path / str(n)
-            for n in self.chunk_cumulative_records(partition_index)
-        ]
+        return (
+            f"PickleChunkedVcfField(partition_chunks={partition_chunks}, "
+            f"path={self.path})"
+        )
+
+    def num_chunks(self, partition_id):
+        return len(self.chunk_cumulative_records(partition_id))
+
+    def chunk_record_index(self, partition_id):
+        if partition_id not in self._chunk_record_index:
+            index_path = self.partition_path(partition_id) / "chunk_index.pkl"
+            with open(index_path, "rb") as f:
+                a = pickle.load(f)
+            assert len(a) > 1
+            assert a[0] == 0
+            self._chunk_record_index[partition_id] = a
+        return self._chunk_record_index[partition_id]
+
+    def chunk_cumulative_records(self, partition_id):
+        return self.chunk_record_index(partition_id)[1:]
+
+    def chunk_num_records(self, partition_id):
+        return np.diff(self.chunk_record_index(partition_id))
+
+    def chunk_files(self, partition_id, start=0):
+        partition_path = self.partition_path(partition_id)
+        for n in self.chunk_cumulative_records(partition_id)[start:]:
+            yield partition_path / f"{n}.pkl"
 
     def read_chunk(self, path):
         with open(path, "rb") as f:
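
The reader side no longer reconstructs chunk boundaries by listing the partition directory: `chunk_record_index` loads an explicit cumulative index (a leading 0, then the running record total after each chunk), and chunk counts, per-chunk sizes, and file names all fall out of simple array operations on it. A minimal sketch of that arithmetic, with made-up numbers:

```python
import numpy as np

# Hypothetical index for one partition, as loaded from chunk_index.pkl:
# a leading 0, then the cumulative record count after each chunk.
chunk_record_index = np.array([0, 1000, 2000, 2500])

cumulative = chunk_record_index[1:]      # chunk_cumulative_records: [1000, 2000, 2500]
num_chunks = len(cumulative)             # 3
per_chunk = np.diff(chunk_record_index)  # records per chunk: [1000, 1000, 500]

# Chunk files are named after the cumulative count at the end of each chunk.
files = [f"{n}.pkl" for n in cumulative]  # ['1000.pkl', '2000.pkl', '2500.pkl']
```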
@@ -570,8 +581,8 @@ def read_chunk(self, path):
     def iter_values_bytes(self):
         num_records = 0
         bytes_read = 0
-        for partition_index in range(self.num_partitions):
-            for chunk_path in self.chunk_files(partition_index):
+        for partition_id in range(self.num_partitions):
+            for chunk_path in self.chunk_files(partition_id):
                 chunk, chunk_bytes = self.read_chunk(chunk_path)
                 bytes_read += chunk_bytes
                 for record in chunk:
@@ -588,19 +599,38 @@ def iter_values(self, start=None, stop=None):
         start_partition = (
             np.searchsorted(self.partition_record_index, start, side="right") - 1
         )
-        num_records = self.partition_record_index[start_partition]
-        assert num_records <= start
-        for partition_index in range(start_partition, self.num_partitions):
-            # TODO use the offsets from the partition chunk counts to seek to
-            # the first chunk
-            for chunk_path in self.chunk_files(partition_index):
+        offset = self.partition_record_index[start_partition]
+        assert offset <= start
+        chunk_offset = start - offset
+
+        chunk_record_index = self.chunk_record_index(start_partition)
+        start_chunk = (
+            np.searchsorted(chunk_record_index, chunk_offset, side="right") - 1
+        )
+        record_id = offset + chunk_record_index[start_chunk]
+        assert record_id <= start
+        logger.debug(
+            f"{self.vcf_field.full_name} slice [{start}:{stop}]: "
+            f"p_start={start_partition}, c_start={start_chunk}, r_start={record_id}"
+        )
+
+        for chunk_path in self.chunk_files(start_partition, start_chunk):
+            chunk, _ = self.read_chunk(chunk_path)
+            for record in chunk:
+                if record_id == stop:
+                    return
+                if record_id >= start:
+                    yield record
+                record_id += 1
+        assert record_id > start
+        for partition_id in range(start_partition + 1, self.num_partitions):
+            for chunk_path in self.chunk_files(partition_id):
                 chunk, _ = self.read_chunk(chunk_path)
                 for record in chunk:
-                    if start <= num_records < stop:
-                        yield record
-                    if num_records >= stop:
+                    if record_id == stop:
                         return
-                    num_records += 1
+                    yield record
+                    record_id += 1
 
     # Note: this involves some computation so should arguably be a method,
     # but making a property for consistency with xarray etc
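
The rewritten `iter_values` seeks with two stacked binary searches instead of scanning every chunk from the beginning: `searchsorted` on `partition_record_index` locates the partition containing `start`, and the same search on that partition's chunk index locates the first chunk worth reading. A sketch of the chunk-level seek with illustrative numbers:

```python
import numpy as np

offset = 10_000  # records in all partitions before this one (partition_record_index)
start = 12_345   # first global record the caller wants
chunk_record_index = np.array([0, 1000, 2000, 3000])  # this partition's index

chunk_offset = start - offset  # 2345: position of start within the partition
start_chunk = np.searchsorted(chunk_record_index, chunk_offset, side="right") - 1
record_id = offset + chunk_record_index[start_chunk]

# 2345 falls in the chunk covering partition records [2000, 3000), so reading
# begins at chunk 2, i.e. global record 12000; records before 12345 are skipped.
assert (start_chunk, record_id) == (2, 12_000)
assert record_id <= start
```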
@@ -648,9 +678,9 @@ class PcvcfFieldWriter:
     transformer: VcfValueTransformer
     compressor: Any
     max_buffered_bytes: int
-    buff: list = dataclasses.field(default_factory=list)
+    buff: List[Any] = dataclasses.field(default_factory=list)
     buffered_bytes: int = 0
-    chunk_index: int = 0
+    chunk_index: List[int] = dataclasses.field(default_factory=lambda: [0])
     num_records: int = 0
 
     def append(self, val):
@@ -662,15 +692,17 @@ def append(self, val):
         self.num_records += 1
         if self.buffered_bytes >= self.max_buffered_bytes:
             logger.debug(
-                f"Flush {self.path} buffered={self.buffered_bytes} max={self.max_buffered_bytes}"
+                f"Flush {self.path} buffered={self.buffered_bytes} "
+                f"max={self.max_buffered_bytes}"
             )
             self.write_chunk()
             self.buff.clear()
             self.buffered_bytes = 0
-            self.chunk_index += 1
 
     def write_chunk(self):
-        path = self.path / f"{self.num_records}"
+        # Update index
+        self.chunk_index.append(self.num_records)
+        path = self.path / f"{self.num_records}.pkl"
         logger.debug(f"Start write: {path}")
         pkl = pickle.dumps(self.buff)
         compressed = self.compressor.encode(pkl)
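
On the writer side, each flush now appends its cumulative record count to the index before naming the chunk file after that same count, so filenames and index entries stay in lockstep. The payload itself is a pickled list passed through the configured compressor; a self-contained sketch of that round trip, assuming a numcodecs-style codec (the actual compressor is whatever the writer was constructed with, and the decode side of `read_chunk` is inferred here):

```python
import pickle

import numcodecs

# Stand-in codec; any object with encode()/decode() over bytes would do.
compressor = numcodecs.Blosc(cname="zstd")
buff = [("chr1", 1234), ("chr1", 5678)]  # illustrative buffered records

compressed = compressor.encode(pickle.dumps(buff))      # what write_chunk stores
restored = pickle.loads(compressor.decode(compressed))  # what read_chunk recovers
assert restored == buff
```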
@@ -689,6 +721,9 @@ def flush(self):
         )
         if len(self.buff) > 0:
             self.write_chunk()
+        with open(self.path / "chunk_index.pkl", "wb") as f:
+            a = np.array(self.chunk_index, dtype=int)
+            pickle.dump(a, f)
 
 
 class PcvcfPartitionWriter(contextlib.AbstractContextManager):
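
`flush` persists the accumulated index as `chunk_index.pkl`, and `chunk_record_index` on the read side loads it back, asserting the invariants the rest of the class depends on: a leading 0 and at least one chunk entry. A round-trip sketch using a temporary directory:

```python
import pickle
import tempfile
from pathlib import Path

import numpy as np

with tempfile.TemporaryDirectory() as tmp:
    index_path = Path(tmp) / "chunk_index.pkl"

    # Writer side (flush): dump the cumulative counts as an int array.
    with open(index_path, "wb") as f:
        pickle.dump(np.array([0, 1000, 2000, 2500], dtype=int), f)

    # Reader side (chunk_record_index): load and sanity-check.
    with open(index_path, "rb") as f:
        a = pickle.load(f)
    assert len(a) > 1
    assert a[0] == 0
```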