Skip to content

Commit 0d9aeb7

Browse files
Merge pull request #116 from jeromekelleher/schema-bug
Fix ColumnSpec initialisation
2 parents 6bbce8b + 0e55447 commit 0d9aeb7

File tree

3 files changed

+96
-15
lines changed

3 files changed

+96
-15
lines changed

CHANGELOG.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,7 @@
1+
# 0.0.5 2024-04-XX
2+
3+
- Fix bug in schema handling (compressor settings ignored)
4+
15
# 0.0.4 2024-04-08
26

37
- Fix bug in --max-memory handling, and change the argument to accept a string like 10G

bio2zarr/vcf.py

Lines changed: 20 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -890,7 +890,6 @@ def num_columns(self):
890890
return len(self.columns)
891891

892892

893-
894893
def mkdir_with_progress(path):
895894
logger.debug(f"mkdir f{path}")
896895
# NOTE we may have race-conditions here, I'm not sure. Hopefully allowing
@@ -1226,20 +1225,25 @@ class ZarrColumnSpec:
12261225
dtype: str
12271226
shape: tuple
12281227
chunks: tuple
1229-
dimensions: list
1228+
dimensions: tuple
12301229
description: str
12311230
vcf_field: str
1232-
compressor: dict = None
1233-
filters: list = None
1234-
# TODO add filters
1231+
compressor: dict
1232+
filters: list
12351233

12361234
def __post_init__(self):
1235+
# Ensure these are tuples for ease of comparison and consistency
12371236
self.shape = tuple(self.shape)
12381237
self.chunks = tuple(self.chunks)
12391238
self.dimensions = tuple(self.dimensions)
1240-
self.compressor = DEFAULT_ZARR_COMPRESSOR.get_config()
1241-
self.filters = []
1242-
self._choose_compressor_settings()
1239+
1240+
@staticmethod
1241+
def new(**kwargs):
1242+
spec = ZarrColumnSpec(
1243+
**kwargs, compressor=DEFAULT_ZARR_COMPRESSOR.get_config(), filters=[]
1244+
)
1245+
spec._choose_compressor_settings()
1246+
return spec
12431247

12441248
def _choose_compressor_settings(self):
12451249
"""
@@ -1315,7 +1319,7 @@ def generate(icf, variants_chunk_size=None, samples_chunk_size=None):
13151319
def fixed_field_spec(
13161320
name, dtype, vcf_field=None, shape=(m,), dimensions=("variants",)
13171321
):
1318-
return ZarrColumnSpec(
1322+
return ZarrColumnSpec.new(
13191323
vcf_field=vcf_field,
13201324
name=name,
13211325
dtype=dtype,
@@ -1399,7 +1403,7 @@ def fixed_field_spec(
13991403
else:
14001404
dimensions.append(f"{field.category}_{field.name}_dim")
14011405
variable_name = prefix + field.name
1402-
colspec = ZarrColumnSpec(
1406+
colspec = ZarrColumnSpec.new(
14031407
vcf_field=field.full_name,
14041408
name=variable_name,
14051409
dtype=field.smallest_dtype(),
@@ -1417,7 +1421,7 @@ def fixed_field_spec(
14171421
dimensions = ["variants", "samples"]
14181422

14191423
colspecs.append(
1420-
ZarrColumnSpec(
1424+
ZarrColumnSpec.new(
14211425
vcf_field=None,
14221426
name="call_genotype_phased",
14231427
dtype="bool",
@@ -1430,7 +1434,7 @@ def fixed_field_spec(
14301434
shape += [ploidy]
14311435
dimensions += ["ploidy"]
14321436
colspecs.append(
1433-
ZarrColumnSpec(
1437+
ZarrColumnSpec.new(
14341438
vcf_field=None,
14351439
name="call_genotype",
14361440
dtype=gt_field.smallest_dtype(),
@@ -1441,7 +1445,7 @@ def fixed_field_spec(
14411445
)
14421446
)
14431447
colspecs.append(
1444-
ZarrColumnSpec(
1448+
ZarrColumnSpec.new(
14451449
vcf_field=None,
14461450
name="call_genotype_mask",
14471451
dtype="bool",
@@ -1523,7 +1527,9 @@ def __init__(self, path, icf, schema, dimension_separator=None):
15231527
self.schema = schema
15241528
# Default to using nested directories following the Zarr v3 default.
15251529
# This seems to require version 2.17+ to work properly
1526-
self.dimension_separator = "/" if dimension_separator is None else dimension_separator
1530+
self.dimension_separator = (
1531+
"/" if dimension_separator is None else dimension_separator
1532+
)
15271533
store = zarr.DirectoryStore(self.path)
15281534
self.root = zarr.group(store=store)
15291535

tests/test_vcf.py

Lines changed: 72 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import pytest
44
import xarray.testing as xt
55
import sgkit as sg
6+
import zarr
67

78
from bio2zarr import vcf
89

@@ -100,7 +101,6 @@ def test_exploded_metadata_mismatch(self, tmpdir, icf_path, version):
100101

101102

102103
class TestEncodeDimensionSeparator:
103-
104104
@pytest.mark.parametrize("dimension_separator", [None, "/"])
105105
def test_directories(self, tmp_path, icf_path, dimension_separator):
106106
zarr_path = tmp_path / "zarr"
@@ -122,6 +122,77 @@ def test_bad_value(self, tmp_path, icf_path, dimension_separator):
122122
vcf.encode(icf_path, zarr_path, dimension_separator=dimension_separator)
123123

124124

125+
class TestSchemaJsonRoundTrip:
    """A schema must compare equal to itself after a trip through JSON."""

    def assert_json_round_trip(self, schema):
        # Serialise, parse the result back, and compare with the input.
        restored = vcf.VcfZarrSchema.fromjson(schema.asjson())
        assert schema == restored

    def _generate_schema(self, icf_path):
        # Not collected by pytest (no ``test_`` prefix): shared setup that
        # generates a schema from the ICF at ``icf_path``.
        return vcf.VcfZarrSchema.generate(vcf.IntermediateColumnarFormat(icf_path))

    def test_generated_no_changes(self, icf_path):
        self.assert_json_round_trip(self._generate_schema(icf_path))

    def test_generated_no_columns(self, icf_path):
        generated = self._generate_schema(icf_path)
        generated.columns.clear()
        self.assert_json_round_trip(generated)

    def test_generated_no_samples(self, icf_path):
        generated = self._generate_schema(icf_path)
        generated.sample_id.clear()
        self.assert_json_round_trip(generated)

    def test_generated_change_dtype(self, icf_path):
        generated = self._generate_schema(icf_path)
        generated.columns["variant_position"].dtype = "i8"
        self.assert_json_round_trip(generated)

    def test_generated_change_compressor(self, icf_path):
        generated = self._generate_schema(icf_path)
        # The value only has to survive the round trip; it is never used
        # to compress anything in this test.
        generated.columns["variant_position"].compressor = {"cname": "FAKE"}
        self.assert_json_round_trip(generated)
157+
158+
159+
class TestSchemaEncode:
    """Edits made to a schema file must show up in the encoded Zarr arrays."""

    def _encode_with_schema(self, tmp_path, icf_path, schema):
        # Not collected by pytest (no ``test_`` prefix): write ``schema`` to
        # disk as JSON, encode using that file, and return the opened output.
        zarr_path = tmp_path / "zarr"
        schema_path = tmp_path / "schema"
        schema_path.write_text(schema.asjson())
        vcf.encode(icf_path, zarr_path, schema_path=schema_path)
        return zarr.open(zarr_path)

    @pytest.mark.parametrize(
        ["cname", "clevel", "shuffle"], [("lz4", 1, 0), ("zlib", 7, 1), ("zstd", 4, 2)]
    )
    def test_codec(self, tmp_path, icf_path, cname, clevel, shuffle):
        schema = vcf.VcfZarrSchema.generate(vcf.IntermediateColumnarFormat(icf_path))
        settings = {"cname": cname, "clevel": clevel, "shuffle": shuffle}
        # Apply the same compressor settings to every column in the schema.
        for column in schema.columns.values():
            column.compressor.update(settings)
        root = self._encode_with_schema(tmp_path, icf_path, schema)
        # Every encoded array must report the settings requested above.
        for column in schema.columns.values():
            compressor = root[column.name].compressor
            for attribute, expected in settings.items():
                assert getattr(compressor, attribute) == expected

    @pytest.mark.parametrize("dtype", ["i4", "i8"])
    def test_genotype_dtype(self, tmp_path, icf_path, dtype):
        schema = vcf.VcfZarrSchema.generate(vcf.IntermediateColumnarFormat(icf_path))
        schema.columns["call_genotype"].dtype = dtype
        root = self._encode_with_schema(tmp_path, icf_path, schema)
        assert root["call_genotype"].dtype == dtype
194+
195+
125196
class TestDefaultSchema:
126197
def test_format_version(self, schema):
127198
assert schema["format_version"] == vcf.ZARR_SCHEMA_FORMAT_VERSION

0 commit comments

Comments
 (0)