Skip to content

Commit 5a602c8

Browse files
Merge pull request #145 from jeromekelleher/better-dencode
Better dencode
2 parents 1c47953 + 3faa3c8 commit 5a602c8

File tree

3 files changed

+194
-54
lines changed

3 files changed

+194
-54
lines changed

CHANGELOG.md

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,8 @@
1+
# 0.0.7 2024-04-30
2+
- Change on-disk format of distributed encode and simplify
3+
- Check for all partitions nominally completed encoding before doing
4+
anything destructive in dencode-finalise
5+
16
# 0.0.6 2024-04-24
27

38
- Only use NOSHUFFLE by default on ``call_genotype`` and bool arrays.

bio2zarr/vcf.py

Lines changed: 64 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -294,17 +294,18 @@ def scan_vcf(path, target_num_partitions):
294294

295295
def check_overlap(partitions):
296296
for i in range(1, len(partitions)):
297-
prev_partition = partitions[i - 1]
298-
current_partition = partitions[i]
299-
if (
300-
prev_partition.region.contig == current_partition.region.contig
301-
and prev_partition.region.end > current_partition.region.start
302-
):
303-
raise ValueError(
304-
f"Multiple VCFs have the region "
305-
f"{prev_partition.region.contig}:{prev_partition.region.start}-"
306-
f"{current_partition.region.end}"
307-
)
297+
prev_region = partitions[i - 1].region
298+
current_region = partitions[i].region
299+
if prev_region.contig == current_region.contig:
300+
if prev_region.end is None:
301+
logger.warning("Cannot check overlaps; issue #146")
302+
continue
303+
if prev_region.end > current_region.start:
304+
raise ValueError(
305+
f"Multiple VCFs have the region "
306+
f"{prev_region.contig}:{prev_region.start}-"
307+
f"{current_region.end}"
308+
)
308309

309310

310311
def scan_vcfs(paths, show_progress, target_num_partitions, worker_processes=1):
@@ -453,7 +454,7 @@ def sanitise_value_float_2d(buff, j, value):
453454

454455
def sanitise_int_array(value, ndmin, dtype):
455456
if isinstance(value, tuple):
456-
value = [VCF_INT_MISSING if x is None else x for x in value] # NEEDS TEST
457+
value = [VCF_INT_MISSING if x is None else x for x in value] # NEEDS TEST
457458
value = np.array(value, ndmin=ndmin, copy=False)
458459
value[value == VCF_INT_MISSING] = -1
459460
value[value == VCF_INT_FILL] = -2
@@ -1548,10 +1549,8 @@ def parse_max_memory(max_memory):
15481549

15491550
@dataclasses.dataclass
15501551
class VcfZarrPartition:
1551-
start_index: int
1552-
stop_index: int
1553-
start_chunk: int
1554-
stop_chunk: int
1552+
start: int
1553+
stop: int
15551554

15561555
@staticmethod
15571556
def generate_partitions(num_records, chunk_size, num_partitions, max_chunks=None):
@@ -1565,9 +1564,7 @@ def generate_partitions(num_records, chunk_size, num_partitions, max_chunks=None
15651564
stop_chunk = int(chunk_slice[-1]) + 1
15661565
start_index = start_chunk * chunk_size
15671566
stop_index = min(stop_chunk * chunk_size, num_records)
1568-
partitions.append(
1569-
VcfZarrPartition(start_index, stop_index, start_chunk, stop_chunk)
1570-
)
1567+
partitions.append(VcfZarrPartition(start_index, stop_index))
15711568
return partitions
15721569

15731570

@@ -1590,7 +1587,7 @@ def asdict(self):
15901587
def fromdict(d):
15911588
if d["format_version"] != VZW_METADATA_FORMAT_VERSION:
15921589
raise ValueError(
1593-
"VcfZarrWriter format version mismatch: "
1590+
"VcfZarrWriter format version mismatch: "
15941591
f"{d['format_version']} != {VZW_METADATA_FORMAT_VERSION}"
15951592
)
15961593
ret = VcfZarrWriterMetadata(**d)
@@ -1675,7 +1672,7 @@ def init(
16751672
root = zarr.group(store=store)
16761673

16771674
for column in self.schema.columns.values():
1678-
self.init_array(root, column, partitions[-1].stop_index)
1675+
self.init_array(root, column, partitions[-1].stop)
16791676

16801677
logger.info("Writing WIP metadata")
16811678
with open(self.wip_path / "metadata.json", "w") as f:
@@ -1762,28 +1759,42 @@ def load_metadata(self):
17621759
def partition_path(self, partition_index):
17631760
return self.partitions_path / f"p{partition_index}"
17641761

1762+
def wip_partition_path(self, partition_index):
1763+
return self.partitions_path / f"wip_p{partition_index}"
1764+
17651765
def wip_partition_array_path(self, partition_index, name):
1766-
return self.partition_path(partition_index) / f"wip_{name}"
1766+
return self.wip_partition_path(partition_index) / name
17671767

17681768
def partition_array_path(self, partition_index, name):
17691769
return self.partition_path(partition_index) / name
17701770

17711771
def encode_partition(self, partition_index):
17721772
self.load_metadata()
1773-
partition_path = self.partition_path(partition_index)
1773+
if partition_index < 0 or partition_index >= self.num_partitions:
1774+
raise ValueError(
1775+
"Partition index must be in the range 0 <= index < num_partitions"
1776+
)
1777+
partition_path = self.wip_partition_path(partition_index)
17741778
partition_path.mkdir(exist_ok=True)
17751779
logger.info(f"Encoding partition {partition_index} to {partition_path}")
17761780

1777-
self.encode_alleles_partition(partition_index)
17781781
self.encode_id_partition(partition_index)
17791782
self.encode_filters_partition(partition_index)
17801783
self.encode_contig_partition(partition_index)
1784+
self.encode_alleles_partition(partition_index)
17811785
for col in self.schema.columns.values():
17821786
if col.vcf_field is not None:
17831787
self.encode_array_partition(col, partition_index)
17841788
if "call_genotype" in self.schema.columns:
17851789
self.encode_genotypes_partition(partition_index)
17861790

1791+
final_path = self.partition_path(partition_index)
1792+
logger.info(f"Finalising {partition_index} at {final_path}")
1793+
if final_path.exists():
1794+
logger.warning("Removing existing partition at {final_path}")
1795+
shutil.rmtree(final_path)
1796+
os.rename(partition_path, final_path)
1797+
17871798
def init_partition_array(self, partition_index, name):
17881799
wip_path = self.wip_partition_array_path(partition_index, name)
17891800
# Create an empty array like the definition
@@ -1795,27 +1806,17 @@ def init_partition_array(self, partition_index, name):
17951806
return array
17961807

17971808
def finalise_partition_array(self, partition_index, name):
1798-
wip_path = self.wip_partition_array_path(partition_index, name)
1799-
final_path = self.partition_array_path(partition_index, name)
1800-
if final_path.exists():
1801-
# NEEDS TEST
1802-
logger.warning(f"Removing existing {final_path}")
1803-
shutil.rmtree(final_path)
1804-
# Atomic swap
1805-
os.rename(wip_path, final_path)
18061809
logger.debug(f"Encoded {name} partition {partition_index}")
18071810

18081811
def encode_array_partition(self, column, partition_index):
18091812
array = self.init_partition_array(partition_index, column.name)
18101813

18111814
partition = self.metadata.partitions[partition_index]
1812-
ba = core.BufferedArray(array, partition.start_index)
1815+
ba = core.BufferedArray(array, partition.start)
18131816
source_col = self.icf.columns[column.vcf_field]
18141817
sanitiser = source_col.sanitiser_factory(ba.buff.shape)
18151818

1816-
for value in source_col.iter_values(
1817-
partition.start_index, partition.stop_index
1818-
):
1819+
for value in source_col.iter_values(partition.start, partition.stop):
18191820
# We write directly into the buffer in the sanitiser function
18201821
# to make it easier to reason about dimension padding
18211822
j = ba.next_buffer_row()
@@ -1831,14 +1832,12 @@ def encode_genotypes_partition(self, partition_index):
18311832
)
18321833

18331834
partition = self.metadata.partitions[partition_index]
1834-
gt = core.BufferedArray(gt_array, partition.start_index)
1835-
gt_mask = core.BufferedArray(gt_mask_array, partition.start_index)
1836-
gt_phased = core.BufferedArray(gt_phased_array, partition.start_index)
1835+
gt = core.BufferedArray(gt_array, partition.start)
1836+
gt_mask = core.BufferedArray(gt_mask_array, partition.start)
1837+
gt_phased = core.BufferedArray(gt_phased_array, partition.start)
18371838

18381839
source_col = self.icf.columns["FORMAT/GT"]
1839-
for value in source_col.iter_values(
1840-
partition.start_index, partition.stop_index
1841-
):
1840+
for value in source_col.iter_values(partition.start, partition.stop):
18421841
j = gt.next_buffer_row()
18431842
sanitise_value_int_2d(gt.buff, j, value[:, :-1])
18441843
j = gt_phased.next_buffer_row()
@@ -1859,13 +1858,13 @@ def encode_alleles_partition(self, partition_index):
18591858
array_name = "variant_allele"
18601859
alleles_array = self.init_partition_array(partition_index, array_name)
18611860
partition = self.metadata.partitions[partition_index]
1862-
alleles = core.BufferedArray(alleles_array, partition.start_index)
1861+
alleles = core.BufferedArray(alleles_array, partition.start)
18631862
ref_col = self.icf.columns["REF"]
18641863
alt_col = self.icf.columns["ALT"]
18651864

18661865
for ref, alt in zip(
1867-
ref_col.iter_values(partition.start_index, partition.stop_index),
1868-
alt_col.iter_values(partition.start_index, partition.stop_index),
1866+
ref_col.iter_values(partition.start, partition.stop),
1867+
alt_col.iter_values(partition.start, partition.stop),
18691868
):
18701869
j = alleles.next_buffer_row()
18711870
alleles.buff[j, :] = STR_FILL
@@ -1879,11 +1878,11 @@ def encode_id_partition(self, partition_index):
18791878
vid_array = self.init_partition_array(partition_index, "variant_id")
18801879
vid_mask_array = self.init_partition_array(partition_index, "variant_id_mask")
18811880
partition = self.metadata.partitions[partition_index]
1882-
vid = core.BufferedArray(vid_array, partition.start_index)
1883-
vid_mask = core.BufferedArray(vid_mask_array, partition.start_index)
1881+
vid = core.BufferedArray(vid_array, partition.start)
1882+
vid_mask = core.BufferedArray(vid_mask_array, partition.start)
18841883
col = self.icf.columns["ID"]
18851884

1886-
for value in col.iter_values(partition.start_index, partition.stop_index):
1885+
for value in col.iter_values(partition.start, partition.stop):
18871886
j = vid.next_buffer_row()
18881887
k = vid_mask.next_buffer_row()
18891888
assert j == k
@@ -1904,10 +1903,10 @@ def encode_filters_partition(self, partition_index):
19041903
array_name = "variant_filter"
19051904
array = self.init_partition_array(partition_index, array_name)
19061905
partition = self.metadata.partitions[partition_index]
1907-
var_filter = core.BufferedArray(array, partition.start_index)
1906+
var_filter = core.BufferedArray(array, partition.start)
19081907

19091908
col = self.icf.columns["FILTERS"]
1910-
for value in col.iter_values(partition.start_index, partition.stop_index):
1909+
for value in col.iter_values(partition.start, partition.stop):
19111910
j = var_filter.next_buffer_row()
19121911
var_filter.buff[j] = False
19131912
for f in value:
@@ -1926,10 +1925,10 @@ def encode_contig_partition(self, partition_index):
19261925
array_name = "variant_contig"
19271926
array = self.init_partition_array(partition_index, array_name)
19281927
partition = self.metadata.partitions[partition_index]
1929-
contig = core.BufferedArray(array, partition.start_index)
1928+
contig = core.BufferedArray(array, partition.start)
19301929
col = self.icf.columns["CHROM"]
19311930

1932-
for value in col.iter_values(partition.start_index, partition.stop_index):
1931+
for value in col.iter_values(partition.start, partition.stop):
19331932
j = contig.next_buffer_row()
19341933
# Note: because we are using the indexes to define the lookups
19351934
# and we always have an index, it seems that we the contig lookup
@@ -1950,7 +1949,7 @@ def finalise_array(self, name):
19501949
if final_path.exists():
19511950
# NEEDS TEST
19521951
raise ValueError(f"Array {name} already exists")
1953-
for partition in range(len(self.metadata.partitions)):
1952+
for partition in range(self.num_partitions):
19541953
# Move all the files in partition dir to dest dir
19551954
src = self.partition_array_path(partition, name)
19561955
if not src.exists():
@@ -1977,6 +1976,15 @@ def finalise_array(self, name):
19771976
def finalise(self, show_progress=False):
19781977
self.load_metadata()
19791978

1979+
logger.info("Scanning {self.num_partitions} partitions")
1980+
missing = []
1981+
# TODO may need a progress bar here
1982+
for partition_id in range(self.num_partitions):
1983+
if not self.partition_path(partition_id).exists():
1984+
missing.append(partition_id)
1985+
if len(missing) > 0:
1986+
raise FileNotFoundError(f"Partitions not encoded: {missing}")
1987+
19801988
progress_config = core.ProgressConfig(
19811989
total=len(self.schema.columns),
19821990
title="Finalise",
@@ -1994,6 +2002,9 @@ def finalise(self, show_progress=False):
19942002
with core.ParallelWorkManager(0, progress_config) as pwm:
19952003
for name in self.schema.columns:
19962004
pwm.submit(self.finalise_array, name)
2005+
logger.debug(f"Removing {self.wip_path}")
2006+
shutil.rmtree(self.wip_path)
2007+
logger.info("Consolidating Zarr metadata")
19972008
zarr.consolidate_metadata(self.path)
19982009

19992010
######################

0 commit comments

Comments
 (0)