Skip to content

Commit f921ea7

Browse files
Basic tests for region index
Closes #311
1 parent 488150b commit f921ea7

File tree

3 files changed

+27
-1
lines changed

3 files changed

+27
-1
lines changed

bio2zarr/vcf2zarr/vcz.py

Lines changed: 2 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -1139,14 +1139,15 @@ def create_index(self):
11391139
)
11401140
c_start_idx = c_end_idx + 1
11411141

1142-
index = np.array(index, dtype=np.int32)
1142+
index = np.array(index, dtype=pos.dtype)
11431143
kwargs = {}
11441144
if not zarr_v3():
11451145
kwargs["dimension_separator"] = self.metadata.dimension_separator
11461146
array = root.array(
11471147
"region_index",
11481148
data=index,
11491149
shape=index.shape,
1150+
chunks=index.shape,
11501151
dtype=index.dtype,
11511152
compressor=numcodecs.Blosc("zstd", clevel=9, shuffle=0),
11521153
**kwargs,

tests/test_vcf_examples.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -438,6 +438,17 @@ def test_vcf_dimensions(self, ds):
438438
def test_vcf_field_description(self, ds, field, description):
439439
assert ds[field].attrs["description"] == description
440440

441+
def test_region_index(self, ds):
442+
assert ds["region_index"].chunks == ((3,), (6,))
443+
region_index = np.array(
444+
[
445+
[0, 0, 111, 112, 112, 2],
446+
[0, 1, 14370, 1235237, 1235237, 6],
447+
[0, 2, 10, 10, 11, 1],
448+
]
449+
)
450+
nt.assert_array_equal(ds["region_index"], region_index)
451+
441452

442453
class TestSmallExampleLocalAlleles:
443454
data_path = "tests/data/vcf/sample.vcf.gz"

tests/test_vcz.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -250,6 +250,20 @@ def test_genotype_dtype(self, tmp_path, icf_path, dtype):
250250
root = zarr.open(zarr_path)
251251
assert root["call_genotype"].dtype == dtype
252252

253+
@pytest.mark.parametrize("dtype", ["i4", "i8"])
254+
def test_region_index_dtype(self, tmp_path, icf_path, dtype):
255+
zarr_path = tmp_path / "zarr"
256+
icf = vcf2zarr.IntermediateColumnarFormat(icf_path)
257+
schema = vcf2zarr.VcfZarrSchema.generate(icf)
258+
schema.field_map()["variant_position"].dtype = dtype
259+
schema_path = tmp_path / "schema"
260+
with open(schema_path, "w") as f:
261+
f.write(schema.asjson())
262+
vcf2zarr.encode(icf_path, zarr_path, schema_path=schema_path)
263+
root = zarr.open(zarr_path)
264+
assert root["variant_position"].dtype == dtype
265+
assert root["region_index"].dtype == dtype
266+
253267

254268
def get_field_dict(a_schema, name):
255269
d = a_schema.asdict()

0 commit comments

Comments
 (0)