Skip to content

Commit 9d1e35b

Browse files
Simplify distributed encode
Closes #138 Closes #139
1 parent f5633de commit 9d1e35b

File tree

3 files changed

+148
-13
lines changed

3 files changed

+148
-13
lines changed

CHANGELOG.md

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,8 @@
1+
# 0.0.7 2024-04-xx
2+
- Change the on-disk format of distributed encode and simplify the implementation
3+
- Check that all partitions have nominally completed encoding before doing
4+
anything destructive in dencode-finalise
5+
16
# 0.0.6 2024-04-24
27

38
- Only use NOSHUFFLE by default on ``call_genotype`` and bool arrays.

bio2zarr/vcf.py

Lines changed: 28 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1758,28 +1758,42 @@ def load_metadata(self):
17581758
def partition_path(self, partition_index):
17591759
return self.partitions_path / f"p{partition_index}"
17601760

1761+
def wip_partition_path(self, partition_index):
1762+
return self.partitions_path / f"wip_p{partition_index}"
1763+
17611764
def wip_partition_array_path(self, partition_index, name):
1762-
return self.partition_path(partition_index) / f"wip_{name}"
1765+
return self.wip_partition_path(partition_index) / name
17631766

17641767
def partition_array_path(self, partition_index, name):
17651768
return self.partition_path(partition_index) / name
17661769

17671770
def encode_partition(self, partition_index):
17681771
self.load_metadata()
1769-
partition_path = self.partition_path(partition_index)
1772+
if partition_index < 0 or partition_index >= self.num_partitions:
1773+
raise ValueError(
1774+
"Partition index must be in the range 0 <= index < num_partitions"
1775+
)
1776+
partition_path = self.wip_partition_path(partition_index)
17701777
partition_path.mkdir(exist_ok=True)
17711778
logger.info(f"Encoding partition {partition_index} to {partition_path}")
17721779

1773-
self.encode_alleles_partition(partition_index)
17741780
self.encode_id_partition(partition_index)
17751781
self.encode_filters_partition(partition_index)
17761782
self.encode_contig_partition(partition_index)
1783+
self.encode_alleles_partition(partition_index)
17771784
for col in self.schema.columns.values():
17781785
if col.vcf_field is not None:
17791786
self.encode_array_partition(col, partition_index)
17801787
if "call_genotype" in self.schema.columns:
17811788
self.encode_genotypes_partition(partition_index)
17821789

1790+
final_path = self.partition_path(partition_index)
1791+
logger.info(f"Finalising {partition_index} at {final_path}")
1792+
if final_path.exists():
1793+
logger.warning(f"Removing existing partition at {final_path}")
1794+
shutil.rmtree(final_path)
1795+
os.rename(partition_path, final_path)
1796+
17831797
def init_partition_array(self, partition_index, name):
17841798
wip_path = self.wip_partition_array_path(partition_index, name)
17851799
# Create an empty array like the definition
@@ -1791,14 +1805,6 @@ def init_partition_array(self, partition_index, name):
17911805
return array
17921806

17931807
def finalise_partition_array(self, partition_index, name):
1794-
wip_path = self.wip_partition_array_path(partition_index, name)
1795-
final_path = self.partition_array_path(partition_index, name)
1796-
if final_path.exists():
1797-
# NEEDS TEST
1798-
logger.warning(f"Removing existing {final_path}")
1799-
shutil.rmtree(final_path)
1800-
# Atomic swap
1801-
os.rename(wip_path, final_path)
18021808
logger.debug(f"Encoded {name} partition {partition_index}")
18031809

18041810
def encode_array_partition(self, column, partition_index):
@@ -1942,7 +1948,7 @@ def finalise_array(self, name):
19421948
if final_path.exists():
19431949
# NEEDS TEST
19441950
raise ValueError(f"Array {name} already exists")
1945-
for partition in range(len(self.metadata.partitions)):
1951+
for partition in range(self.num_partitions):
19461952
# Move all the files in partition dir to dest dir
19471953
src = self.partition_array_path(partition, name)
19481954
if not src.exists():
@@ -1969,6 +1975,13 @@ def finalise_array(self, name):
19691975
def finalise(self, show_progress=False):
19701976
self.load_metadata()
19711977

1978+
missing = []
1979+
for partition_id in range(self.num_partitions):
1980+
if not self.partition_path(partition_id).exists():
1981+
missing.append(partition_id)
1982+
if len(missing) > 0:
1983+
raise FileNotFoundError(f"Partitions not encoded: {missing}")
1984+
19721985
progress_config = core.ProgressConfig(
19731986
total=len(self.schema.columns),
19741987
title="Finalise",
@@ -1986,6 +1999,9 @@ def finalise(self, show_progress=False):
19861999
with core.ParallelWorkManager(0, progress_config) as pwm:
19872000
for name in self.schema.columns:
19882001
pwm.submit(self.finalise_array, name)
2002+
logger.debug(f"Removing {self.wip_path}")
2003+
shutil.rmtree(self.wip_path)
2004+
logger.info("Consolidating Zarr metadata")
19892005
zarr.consolidate_metadata(self.path)
19902006

19912007
######################

tests/test_vcf.py

Lines changed: 115 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
import xarray.testing as xt
66
import zarr
77

8-
from bio2zarr import vcf, vcf_utils
8+
from bio2zarr import core, vcf, vcf_utils
99

1010

1111
@pytest.fixture(scope="module")
@@ -383,3 +383,117 @@ def test_fields(self, schema, field, description):
383383
# ])
384384
# def test_filters(self, schema, filt, description):
385385
# assert schema["filters"][field]["description"] == description
386+
387+
388+
class TestVcfZarrWriterExample:
389+
arrays = (
390+
"variant_contig",
391+
"variant_filter",
392+
"variant_id",
393+
"variant_AA",
394+
"variant_AC",
395+
"variant_AF",
396+
"variant_AN",
397+
"variant_DB",
398+
"variant_DP",
399+
"variant_H2",
400+
"variant_NS",
401+
"variant_position",
402+
"variant_quality",
403+
"variant_allele",
404+
"call_DP",
405+
"call_GQ",
406+
"call_genotype",
407+
"call_genotype_phased",
408+
"call_genotype_mask",
409+
"call_HQ",
410+
)
411+
412+
def test_init_paths(self, icf_path, tmp_path):
413+
zarr_path = tmp_path / "x.zarr"
414+
assert not zarr_path.exists()
415+
num_partitions, _ = vcf.encode_init(
416+
icf_path, zarr_path, 7, variants_chunk_size=3
417+
)
418+
assert num_partitions == 3
419+
assert zarr_path.exists()
420+
wip_path = zarr_path / "wip"
421+
assert wip_path.exists()
422+
wip_partitions_path = wip_path / "partitions"
423+
assert wip_partitions_path.exists()
424+
wip_arrays_path = wip_path / "arrays"
425+
assert wip_arrays_path.exists()
426+
for name in self.arrays:
427+
array_path = wip_arrays_path / name
428+
assert array_path.exists()
429+
with open(wip_path / "metadata.json") as f:
430+
d = json.loads(f.read())
431+
# Basic test
432+
assert len(d["partitions"]) == 3
433+
434+
def test_finalise_paths(self, icf_path, tmp_path):
435+
zarr_path = tmp_path / "x.zarr"
436+
assert not zarr_path.exists()
437+
num_partitions, _ = vcf.encode_init(
438+
icf_path, zarr_path, 7, variants_chunk_size=3
439+
)
440+
wip_path = zarr_path / "wip"
441+
assert wip_path.exists()
442+
for j in range(num_partitions):
443+
vcf.encode_partition(zarr_path, j)
444+
assert (wip_path / "partitions" / f"p{j}").exists()
445+
vcf.encode_finalise(zarr_path)
446+
assert zarr_path.exists()
447+
assert not wip_path.exists()
448+
449+
def test_finalise_no_partitions_fails(self, icf_path, tmp_path):
450+
zarr_path = tmp_path / "x.zarr"
451+
vcf.encode_init(icf_path, zarr_path, 3, variants_chunk_size=3)
452+
with pytest.raises(
453+
FileNotFoundError, match="Partitions not encoded: \\[0, 1, 2\\]"
454+
):
455+
vcf.encode_finalise(zarr_path)
456+
457+
@pytest.mark.parametrize("partition", [0, 1, 2])
458+
def test_finalise_missing_partition_fails(self, icf_path, tmp_path, partition):
459+
zarr_path = tmp_path / "x.zarr"
460+
vcf.encode_init(icf_path, zarr_path, 3, variants_chunk_size=3)
461+
for j in range(3):
462+
if j != partition:
463+
vcf.encode_partition(zarr_path, j)
464+
with pytest.raises(
465+
FileNotFoundError, match=f"Partitions not encoded: \\[{partition}\\]"
466+
):
467+
vcf.encode_finalise(zarr_path)
468+
469+
@pytest.mark.parametrize("partition", [0, 1, 2])
470+
def test_encode_partition(self, icf_path, tmp_path, partition):
471+
zarr_path = tmp_path / "x.zarr"
472+
vcf.encode_init(icf_path, zarr_path, 3, variants_chunk_size=3)
473+
partition_path = zarr_path / "wip" / "partitions" / f"p{partition}"
474+
assert not partition_path.exists()
475+
vcf.encode_partition(zarr_path, partition)
476+
assert partition_path.exists()
477+
478+
def test_double_encode_partition(self, icf_path, tmp_path, caplog):
479+
partition = 1
480+
zarr_path = tmp_path / "x.zarr"
481+
vcf.encode_init(icf_path, zarr_path, 3, variants_chunk_size=3)
482+
partition_path = zarr_path / "wip" / "partitions" / f"p{partition}"
483+
assert not partition_path.exists()
484+
vcf.encode_partition(zarr_path, partition)
485+
assert partition_path.exists()
486+
size = core.du(partition_path)
487+
assert size > 0
488+
with caplog.at_level("WARNING"):
489+
vcf.encode_partition(zarr_path, partition)
490+
assert "Removing existing partition at" in caplog.text
491+
assert partition_path.exists()
492+
assert core.du(partition_path) == size
493+
494+
@pytest.mark.parametrize("partition", [-1, 3, 100])
495+
def test_encode_partition_out_of_range(self, icf_path, tmp_path, partition):
496+
zarr_path = tmp_path / "x.zarr"
497+
vcf.encode_init(icf_path, zarr_path, 3, variants_chunk_size=3)
498+
with pytest.raises(ValueError, match="Partition index must be in the range"):
499+
vcf.encode_partition(zarr_path, partition)

0 commit comments

Comments
 (0)