Skip to content

Commit c2fde10

Browse files
Merge pull request #222 from jeromekelleher/error-on-2g-chunks
Raise an error when too-large chunks are encountered
2 parents a1662c2 + a465728 commit c2fde10

File tree

2 files changed

+108
-22
lines changed

2 files changed

+108
-22
lines changed

bio2zarr/vcf2zarr/vcz.py

Lines changed: 47 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ def inspect(path):
3434

3535

3636
@dataclasses.dataclass
37-
class ZarrColumnSpec:
37+
class ZarrArraySpec:
3838
name: str
3939
dtype: str
4040
shape: tuple
@@ -54,7 +54,7 @@ def __post_init__(self):
5454

5555
@staticmethod
5656
def new(**kwargs):
57-
spec = ZarrColumnSpec(
57+
spec = ZarrArraySpec(
5858
**kwargs, compressor=DEFAULT_ZARR_COMPRESSOR.get_config(), filters=[]
5959
)
6060
spec._choose_compressor_settings()
@@ -94,7 +94,7 @@ def from_field(
9494
dimensions.append("genotypes")
9595
else:
9696
dimensions.append(f"{vcf_field.category}_{vcf_field.name}_dim")
97-
return ZarrColumnSpec.new(
97+
return ZarrArraySpec.new(
9898
vcf_field=vcf_field.full_name,
9999
name=variable_name,
100100
dtype=vcf_field.smallest_dtype(),
@@ -127,6 +127,23 @@ def _choose_compressor_settings(self):
127127

128128
self.compressor["shuffle"] = shuffle
129129

130+
@property
def chunk_nbytes(self):
    """
    Returns the nbytes for a single chunk in this array.

    For every dimension that has an entry in ``chunks``, the chunk extent
    is the smaller of the chunk size and the array's size along that
    dimension. Any dimensions of ``shape`` beyond those listed in
    ``chunks`` contribute their full size. The item count is multiplied by
    the itemsize of ``dtype`` as reported by numpy.

    An IndexError is raised if ``chunks`` has more entries than ``shape``.
    """
    items = 1
    num_chunked_dims = 0
    for dim, chunk_size in enumerate(self.chunks):
        # A chunk never spans more items than the dimension holds.
        items *= min(chunk_size, self.shape[dim])
        num_chunked_dims = dim + 1
    # Include sizes for extra dimensions.
    for dim_size in self.shape[num_chunked_dims:]:
        items *= dim_size
    return items * np.dtype(self.dtype).itemsize
146+
130147
@property
131148
def variant_chunk_nbytes(self):
132149
"""
@@ -157,6 +174,24 @@ class VcfZarrSchema(core.JsonDataclass):
157174
filters: list
158175
fields: list
159176

177+
def validate(self):
    """
    Checks that the schema is well-formed and within required limits.

    Currently the only check is that no field in ``self.fields`` has a
    ``chunk_nbytes`` greater than 2**31 - 1. A ValueError naming the first
    offending field is raised otherwise; nothing is returned on success.
    """
    # This is the Blosc max buffer size
    max_chunk_nbytes = 2**31 - 1
    for field in self.fields:
        # Read the value once so the check and the message agree and the
        # size is not recomputed.
        chunk_nbytes = field.chunk_nbytes
        if chunk_nbytes > max_chunk_nbytes:
            # TODO add some links to documentation here advising how to
            # deal with PL values.
            raise ValueError(
                f"Field {field.name} chunks are too large "
                f"({chunk_nbytes} > 2**31 - 1 bytes). "
                "Either generate a schema and drop this field (if you don't "
                "need it) or reduce the variant or sample chunk sizes."
            )
    # TODO other checks? There must be lots of ways people could mess
    # up the schema leading to cryptic errors.
194+
160195
def field_map(self):
161196
return {field.name: field for field in self.fields}
162197

@@ -171,7 +206,7 @@ def fromdict(d):
171206
ret.samples = [icf.Sample(**sd) for sd in d["samples"]]
172207
ret.contigs = [icf.Contig(**sd) for sd in d["contigs"]]
173208
ret.filters = [icf.Filter(**sd) for sd in d["filters"]]
174-
ret.fields = [ZarrColumnSpec(**sd) for sd in d["fields"]]
209+
ret.fields = [ZarrArraySpec(**sd) for sd in d["fields"]]
175210
return ret
176211

177212
@staticmethod
@@ -192,7 +227,7 @@ def generate(icf, variants_chunk_size=None, samples_chunk_size=None):
192227
)
193228

194229
def spec_from_field(field, variable_name=None):
195-
return ZarrColumnSpec.from_field(
230+
return ZarrArraySpec.from_field(
196231
field,
197232
num_samples=n,
198233
num_variants=m,
@@ -204,7 +239,7 @@ def spec_from_field(field, variable_name=None):
204239
def fixed_field_spec(
205240
name, dtype, vcf_field=None, shape=(m,), dimensions=("variants",)
206241
):
207-
return ZarrColumnSpec.new(
242+
return ZarrArraySpec.new(
208243
vcf_field=vcf_field,
209244
name=name,
210245
dtype=dtype,
@@ -230,13 +265,13 @@ def fixed_field_spec(
230265
),
231266
fixed_field_spec(
232267
name="variant_allele",
233-
dtype="str",
268+
dtype="O",
234269
shape=(m, max_alleles),
235270
dimensions=["variants", "alleles"],
236271
),
237272
fixed_field_spec(
238273
name="variant_id",
239-
dtype="str",
274+
dtype="O",
240275
),
241276
fixed_field_spec(
242277
name="variant_id_mask",
@@ -267,7 +302,7 @@ def fixed_field_spec(
267302
chunks = [variants_chunk_size, samples_chunk_size]
268303
dimensions = ["variants", "samples"]
269304
colspecs.append(
270-
ZarrColumnSpec.new(
305+
ZarrArraySpec.new(
271306
vcf_field=None,
272307
name="call_genotype_phased",
273308
dtype="bool",
@@ -280,7 +315,7 @@ def fixed_field_spec(
280315
shape += [ploidy]
281316
dimensions += ["ploidy"]
282317
colspecs.append(
283-
ZarrColumnSpec.new(
318+
ZarrArraySpec.new(
284319
vcf_field=None,
285320
name="call_genotype",
286321
dtype=gt_field.smallest_dtype(),
@@ -291,7 +326,7 @@ def fixed_field_spec(
291326
)
292327
)
293328
colspecs.append(
294-
ZarrColumnSpec.new(
329+
ZarrArraySpec.new(
295330
vcf_field=None,
296331
name="call_genotype_mask",
297332
dtype="bool",
@@ -447,6 +482,7 @@ def init(
447482
self.icf = icf
448483
if self.path.exists():
449484
raise ValueError("Zarr path already exists") # NEEDS TEST
485+
schema.validate()
450486
partitions = VcfZarrPartition.generate_partitions(
451487
self.icf.num_records,
452488
schema.variants_chunk_size,

tests/test_vcz.py

Lines changed: 61 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,7 @@ def test_not_enough_memory(self, tmp_path, icf_path, max_memory):
6767
with pytest.raises(ValueError, match="Insufficient memory"):
6868
vcf2zarr.encode(icf_path, zarr_path, max_memory=max_memory)
6969

70-
@pytest.mark.parametrize("max_memory", ["150KiB", "200KiB"])
70+
@pytest.mark.parametrize("max_memory", ["315KiB", "500KiB"])
7171
def test_not_enough_memory_for_two(
7272
self, tmp_path, icf_path, zarr_path, caplog, max_memory
7373
):
@@ -214,6 +214,55 @@ def get_field_dict(a_schema, name):
214214
return field
215215

216216

217+
class TestChunkNbytes:
    """Checks the ``chunk_nbytes`` value reported by schema fields."""

    @pytest.mark.parametrize(
        ("field", "value"),
        [
            ("call_genotype", 54),  # 9 * 3 * 2 * 1
            ("call_genotype_phased", 27),
            ("call_genotype_mask", 54),
            ("variant_position", 36),  # 9 * 4
            ("variant_H2", 9),
            ("variant_AC", 18),  # 9 * 2
            # Object fields have an itemsize of 8
            ("variant_AA", 72),  # 9 * 8
            ("variant_allele", 9 * 4 * 8),
        ],
    )
    def test_example_schema(self, schema, field, value):
        """Each named field of the fixture schema reports the expected size."""
        spec = schema.field_map()[field]
        assert spec.chunk_nbytes == value

    def test_chunk_size(self, icf_path, tmp_path):
        """A schema generated with explicit chunk sizes reports sizes to match."""
        icf = vcf2zarr.IntermediateColumnarFormat(icf_path)
        chunked_schema = vcf2zarr.VcfZarrSchema.generate(
            icf, samples_chunk_size=2, variants_chunk_size=3
        )
        specs = chunked_schema.field_map()
        assert specs["call_genotype"].chunk_nbytes == 3 * 2 * 2
        assert specs["variant_position"].chunk_nbytes == 3 * 4
        assert specs["variant_AC"].chunk_nbytes == 3 * 2
246+
247+
class TestValidateSchema:
    """Exercises ``validate`` on either side of the chunk size limit."""

    @pytest.mark.parametrize("size", [2**31, 2**31 + 1, 2**32])
    def test_chunk_too_large(self, schema, size):
        """Sizes of 2**31 and above must be rejected with a ValueError."""
        # Rebuild the schema from its dict form before modifying a field.
        rebuilt = vcf2zarr.VcfZarrSchema.fromdict(schema.asdict())
        spec = rebuilt.field_map()["variant_H2"]
        spec.shape = (size,)
        spec.chunks = (size,)
        with pytest.raises(ValueError, match="Field variant_H2 chunks are too large"):
            rebuilt.validate()

    @pytest.mark.parametrize("size", [2**31 - 1, 2**30])
    def test_chunk_not_too_large(self, schema, size):
        """Sizes up to 2**31 - 1 must pass validation without raising."""
        # Rebuild the schema from its dict form before modifying a field.
        rebuilt = vcf2zarr.VcfZarrSchema.fromdict(schema.asdict())
        spec = rebuilt.field_map()["variant_H2"]
        spec.shape = (size,)
        spec.chunks = (size,)
        rebuilt.validate()
264+
265+
217266
class TestDefaultSchema:
218267
def test_format_version(self, schema):
219268
assert schema.format_version == vcz_mod.ZARR_SCHEMA_FORMAT_VERSION
@@ -359,16 +408,17 @@ class TestVcfDescriptions:
359408
def test_fields(self, schema, field, description):
360409
assert schema.field_map()[field].description == description
361410

362-
# This information is not in the schema yet,
363-
# https://github.com/sgkit-dev/vcf2zarr/issues/123
364-
# @pytest.mark.parametrize(
365-
# ("filt", "description"),
366-
# [
367-
# ("s50","Less than 50% of samples have data"),
368-
# ("q10", "Quality below 10"),
369-
# ])
370-
# def test_filters(self, schema, filt, description):
371-
# assert schema["filters"][field]["description"] == description
411+
@pytest.mark.parametrize(
    ("filt", "description"),
    [
        ("PASS", "All filters passed"),
        ("s50", "Less than 50% of samples have data"),
        ("q10", "Quality below 10"),
    ],
)
def test_filters(self, schema, filt, description):
    """The schema's filter with id ``filt`` carries the given description."""
    description_by_id = {}
    for schema_filter in schema.filters:
        description_by_id[schema_filter.id] = schema_filter.description
    assert description_by_id[filt] == description
372422

373423

374424
class TestVcfZarrWriterExample:

0 commit comments

Comments
 (0)