Skip to content

Commit c3250a9

Browse files
tomwhite authored and jeromekelleher committed
Exclude region_index from dataset comparison when chunking differs
1 parent b4e16e0 commit c3250a9

File tree

1 file changed

+15
-6
lines changed

1 file changed

+15
-6
lines changed

tests/test_vcf_examples.py

Lines changed: 15 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -12,6 +12,13 @@
1212
from bio2zarr import constants, provenance, vcf2zarr
1313

1414

15+
def assert_dataset_equal(ds1, ds2, drop_vars=None):
16+
if drop_vars is None:
17+
xt.assert_equal(ds1, ds2)
18+
else:
19+
xt.assert_equal(ds1.drop_vars(drop_vars), ds2.drop_vars(drop_vars))
20+
21+
1522
class TestSmallExample:
1623
data_path = "tests/data/vcf/sample.vcf.gz"
1724

@@ -273,7 +280,7 @@ def test_chunk_size(
273280
ds2 = sg.load_dataset(out)
274281
# print(ds2.call_genotype.values)
275282
# print(ds.call_genotype.values)
276-
xt.assert_equal(ds, ds2)
283+
assert_dataset_equal(ds, ds2, drop_vars=["region_index"])
277284
assert ds2.call_DP.chunks == (y_chunks, x_chunks)
278285
assert ds2.call_GQ.chunks == (y_chunks, x_chunks)
279286
assert ds2.call_HQ.chunks == (y_chunks, x_chunks, (2,))
@@ -341,8 +348,10 @@ def test_max_variant_chunks(
341348
max_variant_chunks=max_variant_chunks,
342349
)
343350
ds2 = sg.load_dataset(out)
344-
xt.assert_equal(
345-
ds.isel(variants=slice(None, variants_chunk_size * max_variant_chunks)), ds2
351+
assert_dataset_equal(
352+
ds.isel(variants=slice(None, variants_chunk_size * max_variant_chunks)),
353+
ds2,
354+
drop_vars=["region_index"],
346355
)
347356

348357
@pytest.mark.parametrize("worker_processes", [0, 1, 2])
@@ -355,7 +364,7 @@ def test_worker_processes(self, ds, tmp_path, worker_processes):
355364
worker_processes=worker_processes,
356365
)
357366
ds2 = sg.load_dataset(out)
358-
xt.assert_equal(ds, ds2)
367+
assert_dataset_equal(ds, ds2, drop_vars=["region_index"])
359368

360369
def test_inspect(self, tmp_path):
361370
# TODO pretty weak test, we should be doing this better somewhere else
@@ -391,8 +400,8 @@ def test_missing_contig_vcf(self, ds, tmp_path, path):
391400
ds_c1 = ds.isel(variants=ds["variant_contig"].values == id1)
392401
id2 = contig_id_2.index(contig)
393402
ds_c2 = ds2.isel(variants=ds2["variant_contig"].values == id2)
394-
drop_vars = ["contig_id", "variant_contig"]
395-
xt.assert_equal(ds_c1.drop_vars(drop_vars), ds_c2.drop_vars(drop_vars))
403+
drop_vars = ["contig_id", "variant_contig", "region_index"]
404+
assert_dataset_equal(ds_c1, ds_c2, drop_vars=drop_vars)
396405

397406
def test_vcf_dimensions(self, ds):
398407
assert ds.call_genotype.dims == ("variants", "samples", "ploidy")

0 commit comments

Comments
 (0)