diff --git a/bio2zarr/vcf2zarr/vcz.py b/bio2zarr/vcf2zarr/vcz.py index 742e9f61..8dfa9776 100644 --- a/bio2zarr/vcf2zarr/vcz.py +++ b/bio2zarr/vcf2zarr/vcz.py @@ -1139,7 +1139,7 @@ def create_index(self): ) c_start_idx = c_end_idx + 1 - index = np.array(index, dtype=np.int32) + index = np.array(index, dtype=pos.dtype) kwargs = {} if not zarr_v3(): kwargs["dimension_separator"] = self.metadata.dimension_separator @@ -1147,8 +1147,10 @@ def create_index(self): "region_index", data=index, shape=index.shape, + chunks=index.shape, dtype=index.dtype, compressor=numcodecs.Blosc("zstd", clevel=9, shuffle=0), + fill_value=None, **kwargs, ) array.attrs["_ARRAY_DIMENSIONS"] = [ diff --git a/tests/test_vcf_examples.py b/tests/test_vcf_examples.py index 3692fe44..06a0aaf5 100644 --- a/tests/test_vcf_examples.py +++ b/tests/test_vcf_examples.py @@ -438,6 +438,17 @@ def test_vcf_dimensions(self, ds): def test_vcf_field_description(self, ds, field, description): assert ds[field].attrs["description"] == description + def test_region_index(self, ds): + assert ds["region_index"].chunks == ((3,), (6,)) + region_index = np.array( + [ + [0, 0, 111, 112, 112, 2], + [0, 1, 14370, 1235237, 1235237, 6], + [0, 2, 10, 10, 11, 1], + ] + ) + nt.assert_array_equal(ds["region_index"], region_index) + class TestSmallExampleLocalAlleles: data_path = "tests/data/vcf/sample.vcf.gz" diff --git a/tests/test_vcz.py b/tests/test_vcz.py index 19c022bd..11e570ad 100644 --- a/tests/test_vcz.py +++ b/tests/test_vcz.py @@ -250,6 +250,20 @@ def test_genotype_dtype(self, tmp_path, icf_path, dtype): root = zarr.open(zarr_path) assert root["call_genotype"].dtype == dtype + @pytest.mark.parametrize("dtype", ["i4", "i8"]) + def test_region_index_dtype(self, tmp_path, icf_path, dtype): + zarr_path = tmp_path / "zarr" + icf = vcf2zarr.IntermediateColumnarFormat(icf_path) + schema = vcf2zarr.VcfZarrSchema.generate(icf) + schema.field_map()["variant_position"].dtype = dtype + schema_path = tmp_path / "schema" + with open(schema_path, "w") as f: + f.write(schema.asjson()) + vcf2zarr.encode(icf_path, zarr_path, schema_path=schema_path) + root = zarr.open(zarr_path) + assert root["variant_position"].dtype == dtype + assert root["region_index"].dtype == dtype + def get_field_dict(a_schema, name): d = a_schema.asdict()