1212from bio2zarr import constants , provenance , vcf2zarr
1313
1414
15+ def assert_dataset_equal (ds1 , ds2 , drop_vars = None ):
16+ if drop_vars is None :
17+ xt .assert_equal (ds1 , ds2 )
18+ else :
19+ xt .assert_equal (ds1 .drop_vars (drop_vars ), ds2 .drop_vars (drop_vars ))
20+
21+
1522class TestSmallExample :
1623 data_path = "tests/data/vcf/sample.vcf.gz"
1724
@@ -273,7 +280,7 @@ def test_chunk_size(
273280 ds2 = sg .load_dataset (out )
274281 # print(ds2.call_genotype.values)
275282 # print(ds.call_genotype.values)
276- xt . assert_equal (ds , ds2 )
283+ assert_dataset_equal (ds , ds2 , drop_vars = [ "region_index" ] )
277284 assert ds2 .call_DP .chunks == (y_chunks , x_chunks )
278285 assert ds2 .call_GQ .chunks == (y_chunks , x_chunks )
279286 assert ds2 .call_HQ .chunks == (y_chunks , x_chunks , (2 ,))
@@ -341,8 +348,10 @@ def test_max_variant_chunks(
341348 max_variant_chunks = max_variant_chunks ,
342349 )
343350 ds2 = sg .load_dataset (out )
344- xt .assert_equal (
345- ds .isel (variants = slice (None , variants_chunk_size * max_variant_chunks )), ds2
351+ assert_dataset_equal (
352+ ds .isel (variants = slice (None , variants_chunk_size * max_variant_chunks )),
353+ ds2 ,
354+ drop_vars = ["region_index" ],
346355 )
347356
348357 @pytest .mark .parametrize ("worker_processes" , [0 , 1 , 2 ])
@@ -355,7 +364,7 @@ def test_worker_processes(self, ds, tmp_path, worker_processes):
355364 worker_processes = worker_processes ,
356365 )
357366 ds2 = sg .load_dataset (out )
358- xt . assert_equal (ds , ds2 )
367+ assert_dataset_equal (ds , ds2 , drop_vars = [ "region_index" ] )
359368
360369 def test_inspect (self , tmp_path ):
361370 # TODO pretty weak test, we should be doing this better somewhere else
@@ -391,8 +400,8 @@ def test_missing_contig_vcf(self, ds, tmp_path, path):
391400 ds_c1 = ds .isel (variants = ds ["variant_contig" ].values == id1 )
392401 id2 = contig_id_2 .index (contig )
393402 ds_c2 = ds2 .isel (variants = ds2 ["variant_contig" ].values == id2 )
394- drop_vars = ["contig_id" , "variant_contig" ]
395- xt . assert_equal (ds_c1 . drop_vars ( drop_vars ) , ds_c2 . drop_vars ( drop_vars ) )
403+ drop_vars = ["contig_id" , "variant_contig" , "region_index" ]
404+ assert_dataset_equal (ds_c1 , ds_c2 , drop_vars = drop_vars )
396405
397406 def test_vcf_dimensions (self , ds ):
398407 assert ds .call_genotype .dims == ("variants" , "samples" , "ploidy" )
0 commit comments