
Commit b2ccec2

Change chunk files to store cumulative record counts
1 parent d73578a commit b2ccec2
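
The idea behind this commit: each column chunk file is named with the cumulative record count at the end of that chunk (e.g. 118, 236, 354, ...) rather than a sequential c0, c1, c2. A sorted directory listing then doubles as a per-partition record index, so a record offset can be mapped to its chunk with a binary search. The sketch below illustrates the idea only; the filenames and the record offset k are invented, and this is not code from the commit.

    import numpy as np

    # Chunk filenames as they might appear in one partition directory,
    # in arbitrary listdir() order (values invented for illustration).
    filenames = ["236", "118", "708", "354", "472", "590", "826", "933"]

    # Parsing the names as integers and sorting recovers both the chunk
    # order and the cumulative record counts.
    cumulative = np.sort(np.array(filenames, dtype=int))
    print(cumulative)  # [118 236 354 472 590 708 826 933]

    # The chunk containing 0-based record k is the first whose cumulative
    # count exceeds k.
    k = 400
    chunk = int(np.searchsorted(cumulative, k, side="right"))
    print(chunk)  # 3, i.e. the file named "472" (records 354..471)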

File tree

bio2zarr/vcf.py
tests/test_pcvcf.py

2 files changed: +101 −26 lines


bio2zarr/vcf.py

Lines changed: 54 additions & 18 deletions
@@ -528,6 +528,10 @@ def __init__(self, pcvcf, vcf_field):
         self.compressor = pcvcf.compressor
         self.num_partitions = pcvcf.num_partitions
         self.num_records = pcvcf.num_records
+        self.partition_record_index = pcvcf.partition_record_index
+        # A map of partition index to the cumulative number of records
+        # in chunks
+        self._chunk_cumulative_records = {}

     @staticmethod
     def get_path(base_path, vcf_field):
@@ -536,17 +540,29 @@ def get_path(base_path, vcf_field):
         return base_path / vcf_field.category / vcf_field.name

     def __repr__(self):
-        return f"PickleChunkedVcfField(path={self.path})"
+        partition_chunks = [self.num_chunks(j) for j in range(self.num_partitions)]
+        return f"PickleChunkedVcfField(partition_chunks={partition_chunks}, path={self.path})"

     def num_chunks(self, partition_index):
+        return len(self.chunk_files(partition_index))
+
+    def chunk_cumulative_records(self, partition_index):
+        if partition_index not in self._chunk_cumulative_records:
+            partition_path = self.path / f"p{partition_index}"
+            # Let numpy do the string->int parsing
+            a = np.array(os.listdir(partition_path), dtype=int)
+            a.sort()
+            self._chunk_cumulative_records[partition_index] = a
+        return self._chunk_cumulative_records[partition_index]
+
+    def chunk_files(self, partition_index):
         partition_path = self.path / f"p{partition_index}"
-        return len(list(partition_path.iterdir()))
-
-    def chunk_path(self, partition_index, chunk_index):
-        return self.path / f"p{partition_index}" / f"c{chunk_index}"
+        return [
+            partition_path / str(n)
+            for n in self.chunk_cumulative_records(partition_index)
+        ]

-    def read_chunk(self, partition_index, chunk_index):
-        path = self.chunk_path(partition_index, chunk_index)
+    def read_chunk(self, path):
         with open(path, "rb") as f:
             pkl = self.compressor.decode(f.read())
         return pickle.loads(pkl), len(pkl)
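
chunk_cumulative_records above leans on numpy to parse and sort the numeric filenames in one vectorised call, caching the result per partition. A self-contained demonstration of that directory-listing trick (the counts are invented for illustration):

    import os
    import tempfile

    import numpy as np

    with tempfile.TemporaryDirectory() as partition_path:
        # Simulate chunk files flushed at these cumulative record counts.
        for count in [354, 118, 236]:
            open(os.path.join(partition_path, str(count)), "wb").close()

        # One string->int parse over the whole listing, then sort to restore
        # chunk order; os.listdir() makes no ordering guarantees.
        a = np.array(os.listdir(partition_path), dtype=int)
        a.sort()
        print(a)  # [118 236 354]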
@@ -555,8 +571,8 @@ def iter_values_bytes(self):
         num_records = 0
         bytes_read = 0
         for partition_index in range(self.num_partitions):
-            for chunk_index in range(self.num_chunks(partition_index)):
-                chunk, chunk_bytes = self.read_chunk(partition_index, chunk_index)
+            for chunk_path in self.chunk_files(partition_index):
+                chunk, chunk_bytes = self.read_chunk(chunk_path)
                 bytes_read += chunk_bytes
                 for record in chunk:
                     yield record, bytes_read
@@ -569,13 +585,21 @@ def iter_values_bytes(self):
     def iter_values(self, start=None, stop=None):
         start = 0 if start is None else start
         stop = self.num_records if stop is None else stop
-        num_records = 0
-        for partition_index in range(self.num_partitions):
-            for chunk_index in range(self.num_chunks(partition_index)):
-                chunk, chunk_bytes = self.read_chunk(partition_index, chunk_index)
+        start_partition = (
+            np.searchsorted(self.partition_record_index, start, side="right") - 1
+        )
+        num_records = self.partition_record_index[start_partition]
+        assert num_records <= start
+        for partition_index in range(start_partition, self.num_partitions):
+            # TODO use the offsets from the partition chunk counts to seek to
+            # the first chunk
+            for chunk_path in self.chunk_files(partition_index):
+                chunk, _ = self.read_chunk(chunk_path)
                 for record in chunk:
                     if start <= num_records < stop:
                         yield record
+                    if num_records >= stop:
+                        return
                     num_records += 1

     # Note: this involves some computation so should arguably be a method,
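
The searchsorted call in iter_values maps a start offset to the partition that contains it: side="right" places an offset sitting exactly on a partition boundary into the following partition, and the - 1 turns the insertion point into a partition index. A worked example using the partition_record_index values from the TestSlicing cases below:

    import numpy as np

    # Five partitions of 933 records each (see test_partition_record_index).
    partition_record_index = np.array([0, 933, 1866, 2799, 3732, 4665])

    for start in [0, 932, 933, 1000, 4664]:
        p = np.searchsorted(partition_record_index, start, side="right") - 1
        print(start, "->", int(p))
    # 0 -> 0, 932 -> 0, 933 -> 1, 1000 -> 1, 4664 -> 4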
@@ -627,13 +651,15 @@ class PcvcfFieldWriter:
     buff: list = dataclasses.field(default_factory=list)
     buffered_bytes: int = 0
     chunk_index: int = 0
+    num_records: int = 0

     def append(self, val):
         val = self.transformer.transform_and_update_bounds(val)
         assert val is None or isinstance(val, np.ndarray)
         self.buff.append(val)
         val_bytes = sys.getsizeof(val)
         self.buffered_bytes += val_bytes
+        self.num_records += 1
         if self.buffered_bytes >= self.max_buffered_bytes:
             logger.debug(
                 f"Flush {self.path} buffered={self.buffered_bytes} max={self.max_buffered_bytes}"
@@ -644,7 +670,7 @@ def append(self, val):
             self.chunk_index += 1

     def write_chunk(self):
-        path = self.path / f"c{self.chunk_index}"
+        path = self.path / f"{self.num_records}"
         logger.debug(f"Start write: {path}")
         pkl = pickle.dumps(self.buff)
         compressed = self.compressor.encode(pkl)
@@ -667,7 +693,7 @@ def flush(self):

 class PcvcfPartitionWriter(contextlib.AbstractContextManager):
     """
-    Writes the data for a PickleChunkedVcf for a given partition.
+    Writes the data for a PickleChunkedVcf partition.
     """

     def __init__(
@@ -724,11 +750,21 @@ def __init__(self, path, metadata, vcf_header):
         self.metadata = metadata
         self.vcf_header = vcf_header
         self.compressor = self.DEFAULT_COMPRESSOR
-
         self.columns = {}
+        partition_num_records = [
+            partition.num_records for partition in self.metadata.partitions
+        ]
+        # Allow us to find which partition a given record is in
+        self.partition_record_index = np.cumsum([0] + partition_num_records)
         for field in self.metadata.fields:
             self.columns[field.full_name] = PickleChunkedVcfField(self, field)

+    def __repr__(self):
+        return (
+            f"PickleChunkedVcf(fields={len(self)}, partitions={self.num_partitions}, "
+            f"records={self.num_records}, path={self.path})"
+        )
+
     def __getitem__(self, key):
         return self.columns[key]

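Prepending the zero before the cumulative sum is what lets the same array serve both lookups: partition_record_index[j] is the first record of partition j, and the final entry is the total record count. With the per-partition counts from the test data below:

    import numpy as np

    partition_num_records = [933, 933, 933, 933, 933]
    partition_record_index = np.cumsum([0] + partition_num_records)
    print(partition_record_index)  # [   0  933 1866 2799 3732 4665]
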
@@ -931,7 +967,6 @@ def convert(
             json.dump(vcf_metadata.asdict(), f, indent=4)
         with open(out_path / "header.txt", "w") as f:
             f.write(header)
-        return pcvcf


 def explode(
@@ -946,13 +981,14 @@ def explode(
     if out_path.exists():
         shutil.rmtree(out_path)

-    return PickleChunkedVcf.convert(
+    PickleChunkedVcf.convert(
         vcfs,
         out_path,
         column_chunk_size=column_chunk_size,
         worker_processes=worker_processes,
         show_progress=show_progress,
     )
+    return PickleChunkedVcf.load(out_path)


 def inspect(if_path):
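
Since convert no longer returns the in-memory object, explode now round-trips through PickleChunkedVcf.load so callers always get an object constructed from the on-disk metadata, as the new TestSlicing fixture below exercises. A hypothetical call, assuming the module is imported as in the test file and with placeholder paths:

    from bio2zarr import vcf

    # explode() converts the VCFs, then loads the result back from disk.
    pcvcf = vcf.explode(["sample.vcf.gz"], "sample.exploded")
    print(pcvcf)  # PickleChunkedVcf(fields=..., partitions=..., records=..., path=...)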

tests/test_pcvcf.py

Lines changed: 47 additions & 8 deletions
@@ -56,14 +56,6 @@ def test_POS(self, pcvcf):
             [111, 112, 14370, 17330, 1110696, 1230237, 1234567, 1235237, 10],
         )

-    def test_POS_slice(self, pcvcf):
-        col = pcvcf["POS"]
-        v = [row[0] for row in col.values]
-        start = 1
-        stop = 6
-        s = [row[0] for row in col.iter_values(start, stop)]
-        assert v[start:stop] == s
-
     def test_REF(self, pcvcf):
         ref = ["A", "A", "G", "T", "A", "T", "G", "T", "AC"]
         assert pcvcf["REF"].values == ref
@@ -156,3 +148,50 @@ def test_format_string2(self, pcvcf):
         non_missing = [v for v in pcvcf["FORMAT/FS2"].values if v is not None]
         nt.assert_array_equal(non_missing[0], [["bc", "op"], [".", "op"]])
         nt.assert_array_equal(non_missing[1], [["bc", "."], [".", "."]])
+
+
+class TestSlicing:
+    data_path = "tests/data/vcf/multi_contig.vcf.gz"
+
+    @pytest.fixture(scope="class")
+    def pcvcf(self, tmp_path_factory):
+        out = tmp_path_factory.mktemp("data") / "example.exploded"
+        return vcf.explode([self.data_path], out, column_chunk_size=0.0125)
+
+    def test_repr(self, pcvcf):
+        assert repr(pcvcf).startswith(
+            "PickleChunkedVcf(fields=7, partitions=5, records=4665, path="
+        )
+
+    def test_partition_record_index(self, pcvcf):
+        nt.assert_array_equal(
+            pcvcf.partition_record_index, [0, 933, 1866, 2799, 3732, 4665]
+        )
+
+    def test_pos_chunk_records(self, pcvcf):
+        pos = pcvcf["POS"]
+        for j in range(pos.num_partitions):
+            a = pos.chunk_cumulative_records(j)
+            nt.assert_array_equal(a, [118, 236, 354, 472, 590, 708, 826, 933])
+
+    @pytest.mark.parametrize(
+        ["start", "stop"],
+        [
+            (0, 1),
+            (0, 4665),
+            (100, 200),
+            (118, 237),
+            (710, 850),
+            (931, 1000),
+            (1865, 1867),
+            (1866, 2791),
+            (2732, 3200),
+            (4664, 4665),
+        ],
+    )
+    def test_slice(self, pcvcf, start, stop):
+        # TODO put in the actual values here, 5 copies of 0-933
+        col = pcvcf["POS"]
+        pos = np.array(col.values)
+        pos_slice = np.array(list(col.iter_values(start, stop)))
+        nt.assert_array_equal(pos[start:stop], pos_slice)
