
Commit ebbbe3d

Merge pull request #189 from jeromekelleher/schema-mod-tests
Schema mod tests
2 parents 42e2a5a + 4799b49 commit ebbbe3d

5 files changed (+153 -87 lines changed)


bio2zarr/cli.py

Lines changed: 9 additions & 1 deletion
@@ -423,17 +423,25 @@ def dencode_finalise(zarr_path, verbose):
 @click.command(name="convert")
 @vcfs
 @new_zarr_path
+@force
 @variants_chunk_size
 @samples_chunk_size
 @verbose
 @worker_processes
 def convert_vcf(
-    vcfs, zarr_path, variants_chunk_size, samples_chunk_size, verbose, worker_processes
+    vcfs,
+    zarr_path,
+    force,
+    variants_chunk_size,
+    samples_chunk_size,
+    verbose,
+    worker_processes,
 ):
     """
     Convert input VCF(s) directly to vcfzarr (not recommended for large files).
     """
     setup_logging(verbose)
+    check_overwrite_dir(zarr_path, force)
     vcf.convert(
         vcfs,
         zarr_path,
bio2zarr/vcf.py

Lines changed: 28 additions & 27 deletions
@@ -1311,6 +1311,7 @@ def __post_init__(self):
         self.shape = tuple(self.shape)
         self.chunks = tuple(self.chunks)
         self.dimensions = tuple(self.dimensions)
+        self.filters = tuple(self.filters)

     @staticmethod
     def new(**kwargs):
@@ -1396,27 +1397,29 @@ def variant_chunk_nbytes(self):
         for size in self.shape[1:]:
             chunk_items *= size
         dt = np.dtype(self.dtype)
-        if dt.kind == "O":
+        if dt.kind == "O" and "samples" in self.dimensions:
             logger.warning(
                 f"Field {self.name} is a string; max memory usage may "
                 "be a significant underestimate"
             )
         return chunk_items * dt.itemsize


-ZARR_SCHEMA_FORMAT_VERSION = "0.3"
+ZARR_SCHEMA_FORMAT_VERSION = "0.4"


 @dataclasses.dataclass
 class VcfZarrSchema:
     format_version: str
     samples_chunk_size: int
     variants_chunk_size: int
-    dimensions: list
     samples: list
     contigs: list
     filters: list
-    fields: dict
+    fields: list
+
+    def field_map(self):
+        return {field.name: field for field in self.fields}

     def asdict(self):
         return dataclasses.asdict(self)
@@ -1435,9 +1438,7 @@ def fromdict(d):
         ret.samples = [Sample(**sd) for sd in d["samples"]]
         ret.contigs = [Contig(**sd) for sd in d["contigs"]]
         ret.filters = [Filter(**sd) for sd in d["filters"]]
-        ret.fields = {
-            key: ZarrColumnSpec(**value) for key, value in d["fields"].items()
-        }
+        ret.fields = [ZarrColumnSpec(**sd) for sd in d["fields"]]
         return ret

     @staticmethod
@@ -1572,8 +1573,7 @@ def fixed_field_spec(
         format_version=ZARR_SCHEMA_FORMAT_VERSION,
         samples_chunk_size=samples_chunk_size,
         variants_chunk_size=variants_chunk_size,
-        fields={col.name: col for col in colspecs},
-        dimensions=["variants", "samples", "ploidy", "alleles", "filters"],
+        fields=colspecs,
         samples=icf.metadata.samples,
         contigs=icf.metadata.contigs,
         filters=icf.metadata.filters,
@@ -1701,6 +1701,12 @@ def schema(self):
     def num_partitions(self):
         return len(self.metadata.partitions)

+    def has_genotypes(self):
+        for field in self.schema.fields:
+            if field.name == "call_genotype":
+                return True
+        return False
+
     #######################
     # init
     #######################
@@ -1760,7 +1766,7 @@ def init(
         root = zarr.group(store=store)

         total_chunks = 0
-        for field in self.schema.fields.values():
+        for field in self.schema.fields:
             a = self.init_array(root, field, partitions[-1].stop)
             total_chunks += a.nchunks

@@ -1778,9 +1784,7 @@ def init(

     def encode_samples(self, root):
         if self.schema.samples != self.icf.metadata.samples:
-            raise ValueError(
-                "Subsetting or reordering samples not supported currently"
-            )  # NEEDS TEST
+            raise ValueError("Subsetting or reordering samples not supported currently")
         array = root.array(
             "sample_id",
             [sample.id for sample in self.schema.samples],
@@ -1880,10 +1884,10 @@ def encode_partition(self, partition_index):
         self.encode_filters_partition(partition_index)
         self.encode_contig_partition(partition_index)
         self.encode_alleles_partition(partition_index)
-        for col in self.schema.fields.values():
+        for col in self.schema.fields:
             if col.vcf_field is not None:
                 self.encode_array_partition(col, partition_index)
-        if "call_genotype" in self.schema.fields:
+        if self.has_genotypes():
             self.encode_genotypes_partition(partition_index)

         final_path = self.partition_path(partition_index)
@@ -2100,8 +2104,8 @@ def finalise(self, show_progress=False):
         # for multiple workers, or making a standard wrapper for tqdm
         # that allows us to have a consistent look and feel.
         with core.ParallelWorkManager(0, progress_config) as pwm:
-            for name in self.schema.fields:
-                pwm.submit(self.finalise_array, name)
+            for field in self.schema.fields:
+                pwm.submit(self.finalise_array, field.name)
         logger.debug(f"Removing {self.wip_path}")
         shutil.rmtree(self.wip_path)
         logger.info("Consolidating Zarr metadata")
@@ -2116,17 +2120,14 @@ def get_max_encoding_memory(self):
         Return the approximate maximum memory used to encode a variant chunk.
         """
         max_encoding_mem = 0
-        for col in self.schema.fields.values():
+        for col in self.schema.fields:
             max_encoding_mem = max(max_encoding_mem, col.variant_chunk_nbytes)
         gt_mem = 0
-        if "call_genotype" in self.schema.fields:
-            encoded_together = [
-                "call_genotype",
-                "call_genotype_phased",
-                "call_genotype_mask",
-            ]
+        if self.has_genotypes:
             gt_mem = sum(
-                self.schema.fields[col].variant_chunk_nbytes for col in encoded_together
+                field.variant_chunk_nbytes
+                for field in self.schema.fields
+                if field.name.startswith("call_genotype")
             )
         return max(max_encoding_mem, gt_mem)

@@ -2158,7 +2159,7 @@ def encode_all_partitions(
         num_workers = min(max_num_workers, worker_processes)

         total_bytes = 0
-        for col in self.schema.fields.values():
+        for col in self.schema.fields:
             # Open the array definition to get the total size
             total_bytes += zarr.open(self.arrays_path / col.name).nbytes

@@ -2273,7 +2274,7 @@ def convert(
     # TODO add arguments to control location of tmpdir
 ):
     with tempfile.TemporaryDirectory(prefix="vcf2zarr") as tmp:
-        if_dir = pathlib.Path(tmp) / "if"
+        if_dir = pathlib.Path(tmp) / "icf"
         explode(
             if_dir,
             vcfs,
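
For callers, the schema change boils down to this: VcfZarrSchema.fields is now a list of ZarrColumnSpec objects rather than a dict keyed by field name, with the new field_map() helper covering keyed lookups. A small illustrative sketch (schema stands for any VcfZarrSchema instance; how it is obtained is outside this diff):

    # Iteration no longer needs .values():
    for field in schema.fields:
        print(field.name, field.dtype)

    # Keyed access goes through field_map() instead of schema.fields[name]:
    call_gt = schema.field_map()["call_genotype"]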

tests/test_cli.py

Lines changed: 42 additions & 10 deletions
@@ -47,6 +47,13 @@

 DEFAULT_DENCODE_FINALISE_ARGS = dict(show_progress=True)

+DEFAULT_CONVERT_ARGS = dict(
+    variants_chunk_size=None,
+    samples_chunk_size=None,
+    show_progress=True,
+    worker_processes=1,
+)
+

 @dataclasses.dataclass
 class FakeWorkSummary:
@@ -508,11 +515,24 @@ def test_convert_vcf(self, mocked):
         mocked.assert_called_once_with(
             (self.vcf_path,),
             "zarr_path",
-            variants_chunk_size=None,
-            samples_chunk_size=None,
-            worker_processes=1,
-            show_progress=True,
+            **DEFAULT_CONVERT_ARGS,
+        )
+
+    @pytest.mark.parametrize("response", ["n", "N", "No"])
+    @mock.patch("bio2zarr.vcf.convert")
+    def test_vcf_convert_overwrite_zarr_confirm_no(self, mocked, tmp_path, response):
+        zarr_path = tmp_path / "zarr"
+        zarr_path.mkdir()
+        runner = ct.CliRunner(mix_stderr=False)
+        result = runner.invoke(
+            cli.vcf2zarr,
+            f"convert {self.vcf_path} {zarr_path}",
+            catch_exceptions=False,
+            input=response,
         )
+        assert result.exit_code == 1
+        assert "Aborted" in result.stderr
+        mocked.assert_not_called()

     @mock.patch("bio2zarr.plink.convert")
     def test_convert_plink(self, mocked):
@@ -523,13 +543,25 @@ def test_convert_plink(self, mocked):
         assert result.exit_code == 0
         assert len(result.stdout) == 0
         assert len(result.stderr) == 0
+        mocked.assert_called_once_with("in", "out", **DEFAULT_CONVERT_ARGS)
+
+    @pytest.mark.parametrize("response", ["y", "Y", "yes"])
+    @mock.patch("bio2zarr.vcf.convert")
+    def test_vcf_convert_overwrite_zarr_confirm_yes(self, mocked, tmp_path, response):
+        zarr_path = tmp_path / "zarr"
+        zarr_path.mkdir()
+        runner = ct.CliRunner(mix_stderr=False)
+        result = runner.invoke(
+            cli.vcf2zarr,
+            f"convert {self.vcf_path} {zarr_path}",
+            catch_exceptions=False,
+            input=response,
+        )
+        assert result.exit_code == 0
+        assert f"Do you want to overwrite {zarr_path}" in result.stdout
+        assert len(result.stderr) == 0
         mocked.assert_called_once_with(
-            "in",
-            "out",
-            worker_processes=1,
-            samples_chunk_size=None,
-            variants_chunk_size=None,
-            show_progress=True,
+            (self.vcf_path,), str(zarr_path), **DEFAULT_CONVERT_ARGS
         )

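The two new parametrized tests drive click's confirmation prompt through CliRunner's input= argument. A self-contained illustration of that pattern (the demo command below is hypothetical and not part of bio2zarr):

    import click
    import click.testing


    @click.command()
    def demo():
        # click.confirm(..., abort=True) raises Abort on "n"; the test runner
        # surfaces that as exit code 1 with "Aborted!" in the output.
        click.confirm("Do you want to overwrite out.zarr", abort=True)
        click.echo("overwriting")


    runner = click.testing.CliRunner()
    assert runner.invoke(demo, input="n\n").exit_code == 1
    assert runner.invoke(demo, input="y\n").exit_code == 0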
tests/test_icf.py

Lines changed: 1 addition & 1 deletion
@@ -228,7 +228,7 @@ def schema(self, icf):
         ],
     )
     def test_info_schemas(self, schema, name, dtype, shape, dimensions):
-        v = schema.fields[name]
+        v = schema.field_map()[name]
         assert v.dtype == dtype
         assert tuple(v.shape) == shape
         assert v.dimensions == dimensions
