Skip to content

Commit f5633de

Browse files
Simplify VcfWriterMetadata and test version mismatch
1 parent 1c47953 commit f5633de

File tree

2 files changed

+31
-29
lines changed

2 files changed

+31
-29
lines changed

bio2zarr/vcf.py

Lines changed: 21 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -1548,10 +1548,8 @@ def parse_max_memory(max_memory):
15481548

15491549
@dataclasses.dataclass
15501550
class VcfZarrPartition:
1551-
start_index: int
1552-
stop_index: int
1553-
start_chunk: int
1554-
stop_chunk: int
1551+
start: int
1552+
stop: int
15551553

15561554
@staticmethod
15571555
def generate_partitions(num_records, chunk_size, num_partitions, max_chunks=None):
@@ -1565,9 +1563,7 @@ def generate_partitions(num_records, chunk_size, num_partitions, max_chunks=None
15651563
stop_chunk = int(chunk_slice[-1]) + 1
15661564
start_index = start_chunk * chunk_size
15671565
stop_index = min(stop_chunk * chunk_size, num_records)
1568-
partitions.append(
1569-
VcfZarrPartition(start_index, stop_index, start_chunk, stop_chunk)
1570-
)
1566+
partitions.append(VcfZarrPartition(start_index, stop_index))
15711567
return partitions
15721568

15731569

@@ -1590,7 +1586,7 @@ def asdict(self):
15901586
def fromdict(d):
15911587
if d["format_version"] != VZW_METADATA_FORMAT_VERSION:
15921588
raise ValueError(
1593-
"VcfZarrWriter format version mismatch: "
1589+
"VcfZarrWriter format version mismatch: "
15941590
f"{d['format_version']} != {VZW_METADATA_FORMAT_VERSION}"
15951591
)
15961592
ret = VcfZarrWriterMetadata(**d)
@@ -1675,7 +1671,7 @@ def init(
16751671
root = zarr.group(store=store)
16761672

16771673
for column in self.schema.columns.values():
1678-
self.init_array(root, column, partitions[-1].stop_index)
1674+
self.init_array(root, column, partitions[-1].stop)
16791675

16801676
logger.info("Writing WIP metadata")
16811677
with open(self.wip_path / "metadata.json", "w") as f:
@@ -1809,13 +1805,11 @@ def encode_array_partition(self, column, partition_index):
18091805
array = self.init_partition_array(partition_index, column.name)
18101806

18111807
partition = self.metadata.partitions[partition_index]
1812-
ba = core.BufferedArray(array, partition.start_index)
1808+
ba = core.BufferedArray(array, partition.start)
18131809
source_col = self.icf.columns[column.vcf_field]
18141810
sanitiser = source_col.sanitiser_factory(ba.buff.shape)
18151811

1816-
for value in source_col.iter_values(
1817-
partition.start_index, partition.stop_index
1818-
):
1812+
for value in source_col.iter_values(partition.start, partition.stop):
18191813
# We write directly into the buffer in the sanitiser function
18201814
# to make it easier to reason about dimension padding
18211815
j = ba.next_buffer_row()
@@ -1831,14 +1825,12 @@ def encode_genotypes_partition(self, partition_index):
18311825
)
18321826

18331827
partition = self.metadata.partitions[partition_index]
1834-
gt = core.BufferedArray(gt_array, partition.start_index)
1835-
gt_mask = core.BufferedArray(gt_mask_array, partition.start_index)
1836-
gt_phased = core.BufferedArray(gt_phased_array, partition.start_index)
1828+
gt = core.BufferedArray(gt_array, partition.start)
1829+
gt_mask = core.BufferedArray(gt_mask_array, partition.start)
1830+
gt_phased = core.BufferedArray(gt_phased_array, partition.start)
18371831

18381832
source_col = self.icf.columns["FORMAT/GT"]
1839-
for value in source_col.iter_values(
1840-
partition.start_index, partition.stop_index
1841-
):
1833+
for value in source_col.iter_values(partition.start, partition.stop):
18421834
j = gt.next_buffer_row()
18431835
sanitise_value_int_2d(gt.buff, j, value[:, :-1])
18441836
j = gt_phased.next_buffer_row()
@@ -1859,13 +1851,13 @@ def encode_alleles_partition(self, partition_index):
18591851
array_name = "variant_allele"
18601852
alleles_array = self.init_partition_array(partition_index, array_name)
18611853
partition = self.metadata.partitions[partition_index]
1862-
alleles = core.BufferedArray(alleles_array, partition.start_index)
1854+
alleles = core.BufferedArray(alleles_array, partition.start)
18631855
ref_col = self.icf.columns["REF"]
18641856
alt_col = self.icf.columns["ALT"]
18651857

18661858
for ref, alt in zip(
1867-
ref_col.iter_values(partition.start_index, partition.stop_index),
1868-
alt_col.iter_values(partition.start_index, partition.stop_index),
1859+
ref_col.iter_values(partition.start, partition.stop),
1860+
alt_col.iter_values(partition.start, partition.stop),
18691861
):
18701862
j = alleles.next_buffer_row()
18711863
alleles.buff[j, :] = STR_FILL
@@ -1879,11 +1871,11 @@ def encode_id_partition(self, partition_index):
18791871
vid_array = self.init_partition_array(partition_index, "variant_id")
18801872
vid_mask_array = self.init_partition_array(partition_index, "variant_id_mask")
18811873
partition = self.metadata.partitions[partition_index]
1882-
vid = core.BufferedArray(vid_array, partition.start_index)
1883-
vid_mask = core.BufferedArray(vid_mask_array, partition.start_index)
1874+
vid = core.BufferedArray(vid_array, partition.start)
1875+
vid_mask = core.BufferedArray(vid_mask_array, partition.start)
18841876
col = self.icf.columns["ID"]
18851877

1886-
for value in col.iter_values(partition.start_index, partition.stop_index):
1878+
for value in col.iter_values(partition.start, partition.stop):
18871879
j = vid.next_buffer_row()
18881880
k = vid_mask.next_buffer_row()
18891881
assert j == k
@@ -1904,10 +1896,10 @@ def encode_filters_partition(self, partition_index):
19041896
array_name = "variant_filter"
19051897
array = self.init_partition_array(partition_index, array_name)
19061898
partition = self.metadata.partitions[partition_index]
1907-
var_filter = core.BufferedArray(array, partition.start_index)
1899+
var_filter = core.BufferedArray(array, partition.start)
19081900

19091901
col = self.icf.columns["FILTERS"]
1910-
for value in col.iter_values(partition.start_index, partition.stop_index):
1902+
for value in col.iter_values(partition.start, partition.stop):
19111903
j = var_filter.next_buffer_row()
19121904
var_filter.buff[j] = False
19131905
for f in value:
@@ -1926,10 +1918,10 @@ def encode_contig_partition(self, partition_index):
19261918
array_name = "variant_contig"
19271919
array = self.init_partition_array(partition_index, array_name)
19281920
partition = self.metadata.partitions[partition_index]
1929-
contig = core.BufferedArray(array, partition.start_index)
1921+
contig = core.BufferedArray(array, partition.start)
19301922
col = self.icf.columns["CHROM"]
19311923

1932-
for value in col.iter_values(partition.start_index, partition.stop_index):
1924+
for value in col.iter_values(partition.start, partition.stop):
19331925
j = contig.next_buffer_row()
19341926
# Note: because we are using the indexes to define the lookups
19351927
# and we always have an index, it seems that the contig lookup

tests/test_vcf.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -98,6 +98,16 @@ def test_exploded_metadata_mismatch(self, tmpdir, icf_path, version):
9898
):
9999
vcf.IcfMetadata.fromdict(d)
100100

101+
@pytest.mark.parametrize("version", ["0.0", "1.0", "xxxxx", 0.1])
102+
def test_encode_metadata_mismatch(self, tmpdir, icf_path, version):
103+
zarr_path = tmpdir / "zarr"
104+
vcf.encode_init(icf_path, zarr_path, 1)
105+
with open(zarr_path / "wip" / "metadata.json") as f:
106+
d = json.load(f)
107+
d["format_version"] = version
108+
with pytest.raises(ValueError, match="VcfZarrWriter format version mismatch"):
109+
vcf.VcfZarrWriterMetadata.fromdict(d)
110+
101111

102112
class TestEncodeDimensionSeparator:
103113
@pytest.mark.parametrize("dimension_separator", [None, "/"])

0 commit comments

Comments
 (0)