Skip to content

Commit f184967

Browse files
Implement slicing based on chunk indexes
1 parent b2ccec2 commit f184967

File tree

2 files changed: +94 additions, -42 deletions

bio2zarr/vcf.py

Lines changed: 76 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
import math
1212
import tempfile
1313
import contextlib
14-
from typing import Any
14+
from typing import Any, List
1515

1616
import humanfriendly
1717
import cyvcf2
@@ -529,38 +529,49 @@ def __init__(self, pcvcf, vcf_field):
529529
self.num_partitions = pcvcf.num_partitions
530530
self.num_records = pcvcf.num_records
531531
self.partition_record_index = pcvcf.partition_record_index
532-
# A map of partition index to the cumulative number of records
533-
# in chunks
534-
self._chunk_cumulative_records = {}
532+
# A map of partition id to the cumulative number of records
533+
# in chunks within that partition
534+
self._chunk_record_index = {}
535535

536536
@staticmethod
537537
def get_path(base_path, vcf_field):
538538
if vcf_field.category == "fixed":
539539
return base_path / vcf_field.name
540540
return base_path / vcf_field.category / vcf_field.name
541541

542+
def partition_path(self, partition_id):
543+
return self.path / f"p{partition_id}"
544+
542545
def __repr__(self):
543546
partition_chunks = [self.num_chunks(j) for j in range(self.num_partitions)]
544-
return f"PickleChunkedVcfField(partition_chunks={partition_chunks}, path={self.path})"
545-
546-
def num_chunks(self, partition_index):
547-
return len(self.chunk_files(partition_index))
548-
549-
def chunk_cumulative_records(self, partition_index):
550-
if partition_index not in self._chunk_cumulative_records:
551-
partition_path = self.path / f"p{partition_index}"
552-
# Let numpy do the string->int parsing
553-
a = np.array(os.listdir(partition_path), dtype=int)
554-
a.sort()
555-
self._chunk_cumulative_records[partition_index] = a
556-
return self._chunk_cumulative_records[partition_index]
557-
558-
def chunk_files(self, partition_index):
559-
partition_path = self.path / f"p{partition_index}"
560-
return [
561-
partition_path / str(n)
562-
for n in self.chunk_cumulative_records(partition_index)
563-
]
547+
return (
548+
f"PickleChunkedVcfField(partition_chunks={partition_chunks}, "
549+
f"path={self.path})"
550+
)
551+
552+
def num_chunks(self, partition_id):
553+
return len(self.chunk_cumulative_records(partition_id))
554+
555+
def chunk_record_index(self, partition_id):
556+
if partition_id not in self._chunk_record_index:
557+
index_path = self.partition_path(partition_id) / "chunk_index.pkl"
558+
with open(index_path, "rb") as f:
559+
a = pickle.load(f)
560+
assert len(a) > 1
561+
assert a[0] == 0
562+
self._chunk_record_index[partition_id] = a
563+
return self._chunk_record_index[partition_id]
564+
565+
def chunk_cumulative_records(self, partition_id):
566+
return self.chunk_record_index(partition_id)[1:]
567+
568+
def chunk_num_records(self, partition_id):
569+
return np.diff(self.chunk_cumulative_records(partition_id))
570+
571+
def chunk_files(self, partition_id, start=0):
572+
partition_path = self.partition_path(partition_id)
573+
for n in self.chunk_cumulative_records(partition_id)[start:]:
574+
yield partition_path / f"{n}.pkl"
564575

565576
def read_chunk(self, path):
566577
with open(path, "rb") as f:
@@ -570,8 +581,8 @@ def read_chunk(self, path):
570581
def iter_values_bytes(self):
571582
num_records = 0
572583
bytes_read = 0
573-
for partition_index in range(self.num_partitions):
574-
for chunk_path in self.chunk_files(partition_index):
584+
for partition_id in range(self.num_partitions):
585+
for chunk_path in self.chunk_files(partition_id):
575586
chunk, chunk_bytes = self.read_chunk(chunk_path)
576587
bytes_read += chunk_bytes
577588
for record in chunk:
@@ -588,19 +599,38 @@ def iter_values(self, start=None, stop=None):
588599
start_partition = (
589600
np.searchsorted(self.partition_record_index, start, side="right") - 1
590601
)
591-
num_records = self.partition_record_index[start_partition]
592-
assert num_records <= start
593-
for partition_index in range(start_partition, self.num_partitions):
594-
# TODO use the offsets from the partition chunk counts to seek to
595-
# the first chunk
596-
for chunk_path in self.chunk_files(partition_index):
602+
offset = self.partition_record_index[start_partition]
603+
assert offset <= start
604+
chunk_offset = start - offset
605+
606+
chunk_record_index = self.chunk_record_index(start_partition)
607+
start_chunk = (
608+
np.searchsorted(chunk_record_index, chunk_offset, side="right") - 1
609+
)
610+
record_id = offset + chunk_record_index[start_chunk]
611+
assert record_id <= start
612+
logger.debug(
613+
f"{self.vcf_field.full_name} slice [{start}:{stop}]:"
614+
f"p_start={start_partition}, c_start={start_chunk}, r_start={record_id}"
615+
)
616+
617+
for chunk_path in self.chunk_files(start_partition, start_chunk):
618+
chunk, _ = self.read_chunk(chunk_path)
619+
for record in chunk:
620+
if record_id == stop:
621+
return
622+
if record_id >= start:
623+
yield record
624+
record_id += 1
625+
assert record_id > start
626+
for partition_id in range(start_partition + 1, self.num_partitions):
627+
for chunk_path in self.chunk_files(partition_id):
597628
chunk, _ = self.read_chunk(chunk_path)
598629
for record in chunk:
599-
if start <= num_records < stop:
600-
yield record
601-
if num_records >= stop:
630+
if record_id == stop:
602631
return
603-
num_records += 1
632+
yield record
633+
record_id += 1
604634

605635
# Note: this involves some computation so should arguably be a method,
606636
# but making a property for consistency with xarray etc
@@ -648,9 +678,9 @@ class PcvcfFieldWriter:
648678
transformer: VcfValueTransformer
649679
compressor: Any
650680
max_buffered_bytes: int
651-
buff: list = dataclasses.field(default_factory=list)
681+
buff: List[Any] = dataclasses.field(default_factory=list)
652682
buffered_bytes: int = 0
653-
chunk_index: int = 0
683+
chunk_index: List[int] = dataclasses.field(default_factory=lambda: [0])
654684
num_records: int = 0
655685

656686
def append(self, val):
@@ -662,15 +692,17 @@ def append(self, val):
662692
self.num_records += 1
663693
if self.buffered_bytes >= self.max_buffered_bytes:
664694
logger.debug(
665-
f"Flush {self.path} buffered={self.buffered_bytes} max={self.max_buffered_bytes}"
695+
f"Flush {self.path} buffered={self.buffered_bytes} "
696+
f"max={self.max_buffered_bytes}"
666697
)
667698
self.write_chunk()
668699
self.buff.clear()
669700
self.buffered_bytes = 0
670-
self.chunk_index += 1
671701

672702
def write_chunk(self):
673-
path = self.path / f"{self.num_records}"
703+
# Update index
704+
self.chunk_index.append(self.num_records)
705+
path = self.path / f"{self.num_records}.pkl"
674706
logger.debug(f"Start write: {path}")
675707
pkl = pickle.dumps(self.buff)
676708
compressed = self.compressor.encode(pkl)
@@ -689,6 +721,9 @@ def flush(self):
689721
)
690722
if len(self.buff) > 0:
691723
self.write_chunk()
724+
with open(self.path / "chunk_index.pkl", "wb") as f:
725+
a = np.array(self.chunk_index, dtype=int)
726+
pickle.dump(a, f)
692727

693728

694729
class PcvcfPartitionWriter(contextlib.AbstractContextManager):

tests/test_pcvcf.py

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -168,29 +168,46 @@ def test_partition_record_index(self, pcvcf):
168168
pcvcf.partition_record_index, [0, 933, 1866, 2799, 3732, 4665]
169169
)
170170

171+
def test_pos_values(self, pcvcf):
172+
col = pcvcf["POS"]
173+
pos = np.array([v[0] for v in col.values])
174+
# Check the actual values here to make sure other tests make sense
175+
actual = np.hstack([1 + np.arange(933) for _ in range(5)])
176+
nt.assert_array_equal(pos, actual)
177+
171178
def test_pos_chunk_records(self, pcvcf):
172179
pos = pcvcf["POS"]
173180
for j in range(pos.num_partitions):
181+
a = pos.chunk_record_index(j)
182+
nt.assert_array_equal(a, [0, 118, 236, 354, 472, 590, 708, 826, 933])
174183
a = pos.chunk_cumulative_records(j)
175184
nt.assert_array_equal(a, [118, 236, 354, 472, 590, 708, 826, 933])
185+
a = pos.chunk_num_records(j)
186+
nt.assert_array_equal(a, [118, 118, 118, 118, 118, 118, 107])
176187

177188
@pytest.mark.parametrize(
178189
["start", "stop"],
179190
[
180191
(0, 1),
181192
(0, 4665),
182193
(100, 200),
194+
(100, 500),
195+
(100, 1000),
196+
(100, 1500),
197+
(100, 4500),
198+
(2000, 2500),
183199
(118, 237),
184200
(710, 850),
185201
(931, 1000),
186202
(1865, 1867),
187203
(1866, 2791),
188204
(2732, 3200),
205+
(2798, 2799),
206+
(2799, 2800),
189207
(4664, 4665),
190208
],
191209
)
192210
def test_slice(self, pcvcf, start, stop):
193-
# TODO put in the actual values here, 5 copies of 0-933
194211
col = pcvcf["POS"]
195212
pos = np.array(col.values)
196213
pos_slice = np.array(list(col.iter_values(start, stop)))

0 commit comments

Comments (0)