Skip to content

Commit b5ea524

Browse files
Remove redundant dictionary in Schema format
Breaking change for ongoing encode operations
1 parent 13c24f0 commit b5ea524

File tree

3 files changed

+73
-61
lines changed

3 files changed

+73
-61
lines changed

bio2zarr/vcf.py

Lines changed: 25 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1311,6 +1311,7 @@ def __post_init__(self):
13111311
self.shape = tuple(self.shape)
13121312
self.chunks = tuple(self.chunks)
13131313
self.dimensions = tuple(self.dimensions)
1314+
self.filters = tuple(self.filters)
13141315

13151316
@staticmethod
13161317
def new(**kwargs):
@@ -1404,7 +1405,7 @@ def variant_chunk_nbytes(self):
14041405
return chunk_items * dt.itemsize
14051406

14061407

1407-
ZARR_SCHEMA_FORMAT_VERSION = "0.3"
1408+
ZARR_SCHEMA_FORMAT_VERSION = "0.4"
14081409

14091410

14101411
@dataclasses.dataclass
@@ -1416,7 +1417,10 @@ class VcfZarrSchema:
14161417
samples: list
14171418
contigs: list
14181419
filters: list
1419-
fields: dict
1420+
fields: list
1421+
1422+
def field_map(self):
1423+
return {field.name: field for field in self.fields}
14201424

14211425
def asdict(self):
14221426
return dataclasses.asdict(self)
@@ -1435,9 +1439,7 @@ def fromdict(d):
14351439
ret.samples = [Sample(**sd) for sd in d["samples"]]
14361440
ret.contigs = [Contig(**sd) for sd in d["contigs"]]
14371441
ret.filters = [Filter(**sd) for sd in d["filters"]]
1438-
ret.fields = {
1439-
key: ZarrColumnSpec(**value) for key, value in d["fields"].items()
1440-
}
1442+
ret.fields = [ZarrColumnSpec(**sd) for sd in d["fields"]]
14411443
return ret
14421444

14431445
@staticmethod
@@ -1572,7 +1574,7 @@ def fixed_field_spec(
15721574
format_version=ZARR_SCHEMA_FORMAT_VERSION,
15731575
samples_chunk_size=samples_chunk_size,
15741576
variants_chunk_size=variants_chunk_size,
1575-
fields={col.name: col for col in colspecs},
1577+
fields=colspecs,
15761578
dimensions=["variants", "samples", "ploidy", "alleles", "filters"],
15771579
samples=icf.metadata.samples,
15781580
contigs=icf.metadata.contigs,
@@ -1701,6 +1703,12 @@ def schema(self):
17011703
def num_partitions(self):
17021704
return len(self.metadata.partitions)
17031705

1706+
def has_genotypes(self):
1707+
for field in self.schema.fields:
1708+
if field.name == "call_genotype":
1709+
return True
1710+
return False
1711+
17041712
#######################
17051713
# init
17061714
#######################
@@ -1760,7 +1768,7 @@ def init(
17601768
root = zarr.group(store=store)
17611769

17621770
total_chunks = 0
1763-
for field in self.schema.fields.values():
1771+
for field in self.schema.fields:
17641772
a = self.init_array(root, field, partitions[-1].stop)
17651773
total_chunks += a.nchunks
17661774

@@ -1880,10 +1888,10 @@ def encode_partition(self, partition_index):
18801888
self.encode_filters_partition(partition_index)
18811889
self.encode_contig_partition(partition_index)
18821890
self.encode_alleles_partition(partition_index)
1883-
for col in self.schema.fields.values():
1891+
for col in self.schema.fields:
18841892
if col.vcf_field is not None:
18851893
self.encode_array_partition(col, partition_index)
1886-
if "call_genotype" in self.schema.fields:
1894+
if self.has_genotypes():
18871895
self.encode_genotypes_partition(partition_index)
18881896

18891897
final_path = self.partition_path(partition_index)
@@ -2100,8 +2108,8 @@ def finalise(self, show_progress=False):
21002108
# for multiple workers, or making a standard wrapper for tqdm
21012109
# that allows us to have a consistent look and feel.
21022110
with core.ParallelWorkManager(0, progress_config) as pwm:
2103-
for name in self.schema.fields:
2104-
pwm.submit(self.finalise_array, name)
2111+
for field in self.schema.fields:
2112+
pwm.submit(self.finalise_array, field.name)
21052113
logger.debug(f"Removing {self.wip_path}")
21062114
shutil.rmtree(self.wip_path)
21072115
logger.info("Consolidating Zarr metadata")
@@ -2116,17 +2124,14 @@ def get_max_encoding_memory(self):
21162124
Return the approximate maximum memory used to encode a variant chunk.
21172125
"""
21182126
max_encoding_mem = 0
2119-
for col in self.schema.fields.values():
2127+
for col in self.schema.fields:
21202128
max_encoding_mem = max(max_encoding_mem, col.variant_chunk_nbytes)
21212129
gt_mem = 0
2122-
if "call_genotype" in self.schema.fields:
2123-
encoded_together = [
2124-
"call_genotype",
2125-
"call_genotype_phased",
2126-
"call_genotype_mask",
2127-
]
2130+
if self.has_genotypes():
21282131
gt_mem = sum(
2129-
self.schema.fields[col].variant_chunk_nbytes for col in encoded_together
2132+
field.variant_chunk_nbytes
2133+
for field in self.schema.fields
2134+
if field.name.startswith("call_genotype")
21302135
)
21312136
return max(max_encoding_mem, gt_mem)
21322137

@@ -2158,7 +2163,7 @@ def encode_all_partitions(
21582163
num_workers = min(max_num_workers, worker_processes)
21592164

21602165
total_bytes = 0
2161-
for col in self.schema.fields.values():
2166+
for col in self.schema.fields:
21622167
# Open the array definition to get the total size
21632168
total_bytes += zarr.open(self.arrays_path / col.name).nbytes
21642169

tests/test_icf.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -228,7 +228,7 @@ def schema(self, icf):
228228
],
229229
)
230230
def test_info_schemas(self, schema, name, dtype, shape, dimensions):
231-
v = schema.fields[name]
231+
v = schema.field_map()[name]
232232
assert v.dtype == dtype
233233
assert tuple(v.shape) == shape
234234
assert v.dimensions == dimensions

tests/test_vcf.py

Lines changed: 47 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ def schema_path(icf_path, tmp_path_factory):
3232
@pytest.fixture(scope="module")
3333
def schema(schema_path):
3434
with open(schema_path) as f:
35-
return json.load(f)
35+
return vcf.VcfZarrSchema.fromjson(f.read())
3636

3737

3838
@pytest.fixture(scope="module")
@@ -83,7 +83,7 @@ def test_not_enough_memory_for_two(
8383
class TestJsonVersions:
8484
@pytest.mark.parametrize("version", ["0.1", "1.0", "xxxxx", 0.2])
8585
def test_zarr_schema_mismatch(self, schema, version):
86-
d = dict(schema)
86+
d = schema.asdict()
8787
d["format_version"] = version
8888
with pytest.raises(ValueError, match="Zarr schema format version mismatch"):
8989
vcf.VcfZarrSchema.fromdict(d)
@@ -156,13 +156,13 @@ def test_generated_no_samples(self, icf_path):
156156
def test_generated_change_dtype(self, icf_path):
157157
icf = vcf.IntermediateColumnarFormat(icf_path)
158158
schema = vcf.VcfZarrSchema.generate(icf)
159-
schema.fields["variant_position"].dtype = "i8"
159+
schema.field_map()["variant_position"].dtype = "i8"
160160
self.assert_json_round_trip(schema)
161161

162162
def test_generated_change_compressor(self, icf_path):
163163
icf = vcf.IntermediateColumnarFormat(icf_path)
164164
schema = vcf.VcfZarrSchema.generate(icf)
165-
schema.fields["variant_position"].compressor = {"cname": "FAKE"}
165+
schema.field_map()["variant_position"].compressor = {"cname": "FAKE"}
166166
self.assert_json_round_trip(schema)
167167

168168

@@ -174,7 +174,7 @@ def test_codec(self, tmp_path, icf_path, cname, clevel, shuffle):
174174
zarr_path = tmp_path / "zarr"
175175
icf = vcf.IntermediateColumnarFormat(icf_path)
176176
schema = vcf.VcfZarrSchema.generate(icf)
177-
for var in schema.fields.values():
177+
for var in schema.fields:
178178
var.compressor["cname"] = cname
179179
var.compressor["clevel"] = clevel
180180
var.compressor["shuffle"] = shuffle
@@ -183,7 +183,7 @@ def test_codec(self, tmp_path, icf_path, cname, clevel, shuffle):
183183
f.write(schema.asjson())
184184
vcf.encode(icf_path, zarr_path, schema_path=schema_path)
185185
root = zarr.open(zarr_path)
186-
for var in schema.fields.values():
186+
for var in schema.fields:
187187
a = root[var.name]
188188
assert a.compressor.cname == cname
189189
assert a.compressor.clevel == clevel
@@ -194,7 +194,7 @@ def test_genotype_dtype(self, tmp_path, icf_path, dtype):
194194
zarr_path = tmp_path / "zarr"
195195
icf = vcf.IntermediateColumnarFormat(icf_path)
196196
schema = vcf.VcfZarrSchema.generate(icf)
197-
schema.fields["call_genotype"].dtype = dtype
197+
schema.field_map()["call_genotype"].dtype = dtype
198198
schema_path = tmp_path / "schema"
199199
with open(schema_path, "w") as f:
200200
f.write(schema.asjson())
@@ -203,16 +203,23 @@ def test_genotype_dtype(self, tmp_path, icf_path, dtype):
203203
assert root["call_genotype"].dtype == dtype
204204

205205

206+
def get_field_dict(a_schema, name):
207+
d = a_schema.asdict()
208+
for field in d["fields"]:
209+
if field["name"] == name:
210+
return field
211+
212+
206213
class TestDefaultSchema:
207214
def test_format_version(self, schema):
208-
assert schema["format_version"] == vcf.ZARR_SCHEMA_FORMAT_VERSION
215+
assert schema.format_version == vcf.ZARR_SCHEMA_FORMAT_VERSION
209216

210217
def test_chunk_size(self, schema):
211-
assert schema["samples_chunk_size"] == 1000
212-
assert schema["variants_chunk_size"] == 10000
218+
assert schema.samples_chunk_size == 1000
219+
assert schema.variants_chunk_size == 10000
213220

214221
def test_dimensions(self, schema):
215-
assert schema["dimensions"] == [
222+
assert schema.dimensions == [
216223
"variants",
217224
"samples",
218225
"ploidy",
@@ -221,29 +228,29 @@ def test_dimensions(self, schema):
221228
]
222229

223230
def test_samples(self, schema):
224-
assert schema["samples"] == [
231+
assert schema.asdict()["samples"] == [
225232
{"id": s} for s in ["NA00001", "NA00002", "NA00003"]
226233
]
227234

228235
def test_contigs(self, schema):
229-
assert schema["contigs"] == [
236+
assert schema.asdict()["contigs"] == [
230237
{"id": s, "length": None} for s in ["19", "20", "X"]
231238
]
232239

233240
def test_filters(self, schema):
234-
assert schema["filters"] == [
241+
assert schema.asdict()["filters"] == [
235242
{"id": "PASS", "description": "All filters passed"},
236243
{"id": "s50", "description": "Less than 50% of samples have data"},
237244
{"id": "q10", "description": "Quality below 10"},
238245
]
239246

240247
def test_variant_contig(self, schema):
241-
assert schema["fields"]["variant_contig"] == {
248+
assert get_field_dict(schema, "variant_contig") == {
242249
"name": "variant_contig",
243250
"dtype": "i1",
244-
"shape": [9],
245-
"chunks": [10000],
246-
"dimensions": ["variants"],
251+
"shape": (9,),
252+
"chunks": (10000,),
253+
"dimensions": ("variants",),
247254
"description": "",
248255
"vcf_field": None,
249256
"compressor": {
@@ -253,16 +260,16 @@ def test_variant_contig(self, schema):
253260
"shuffle": 0,
254261
"blocksize": 0,
255262
},
256-
"filters": [],
263+
"filters": tuple(),
257264
}
258265

259266
def test_call_genotype(self, schema):
260-
assert schema["fields"]["call_genotype"] == {
267+
assert get_field_dict(schema, "call_genotype") == {
261268
"name": "call_genotype",
262269
"dtype": "i1",
263-
"shape": [9, 3, 2],
264-
"chunks": [10000, 1000],
265-
"dimensions": ["variants", "samples", "ploidy"],
270+
"shape": (9, 3, 2),
271+
"chunks": (10000, 1000),
272+
"dimensions": ("variants", "samples", "ploidy"),
266273
"description": "",
267274
"vcf_field": None,
268275
"compressor": {
@@ -272,16 +279,16 @@ def test_call_genotype(self, schema):
272279
"shuffle": 2,
273280
"blocksize": 0,
274281
},
275-
"filters": [],
282+
"filters": tuple(),
276283
}
277284

278285
def test_call_genotype_mask(self, schema):
279-
assert schema["fields"]["call_genotype_mask"] == {
286+
assert get_field_dict(schema, "call_genotype_mask") == {
280287
"name": "call_genotype_mask",
281288
"dtype": "bool",
282-
"shape": [9, 3, 2],
283-
"chunks": [10000, 1000],
284-
"dimensions": ["variants", "samples", "ploidy"],
289+
"shape": (9, 3, 2),
290+
"chunks": (10000, 1000),
291+
"dimensions": ("variants", "samples", "ploidy"),
285292
"description": "",
286293
"vcf_field": None,
287294
"compressor": {
@@ -291,16 +298,16 @@ def test_call_genotype_mask(self, schema):
291298
"shuffle": 2,
292299
"blocksize": 0,
293300
},
294-
"filters": [],
301+
"filters": tuple(),
295302
}
296303

297304
def test_call_genotype_phased(self, schema):
298-
assert schema["fields"]["call_genotype_mask"] == {
305+
assert get_field_dict(schema, "call_genotype_mask") == {
299306
"name": "call_genotype_mask",
300307
"dtype": "bool",
301-
"shape": [9, 3, 2],
302-
"chunks": [10000, 1000],
303-
"dimensions": ["variants", "samples", "ploidy"],
308+
"shape": (9, 3, 2),
309+
"chunks": (10000, 1000),
310+
"dimensions": ("variants", "samples", "ploidy"),
304311
"description": "",
305312
"vcf_field": None,
306313
"compressor": {
@@ -310,16 +317,16 @@ def test_call_genotype_phased(self, schema):
310317
"shuffle": 2,
311318
"blocksize": 0,
312319
},
313-
"filters": [],
320+
"filters": tuple(),
314321
}
315322

316323
def test_call_GQ(self, schema):
317-
assert schema["fields"]["call_GQ"] == {
324+
assert get_field_dict(schema, "call_GQ") == {
318325
"name": "call_GQ",
319326
"dtype": "i1",
320-
"shape": [9, 3],
321-
"chunks": [10000, 1000],
322-
"dimensions": ["variants", "samples"],
327+
"shape": (9, 3),
328+
"chunks": (10000, 1000),
329+
"dimensions": ("variants", "samples"),
323330
"description": "Genotype Quality",
324331
"vcf_field": "FORMAT/GQ",
325332
"compressor": {
@@ -329,7 +336,7 @@ def test_call_GQ(self, schema):
329336
"shuffle": 0,
330337
"blocksize": 0,
331338
},
332-
"filters": [],
339+
"filters": tuple(),
333340
}
334341

335342

@@ -379,7 +386,7 @@ class TestVcfDescriptions:
379386
],
380387
)
381388
def test_fields(self, schema, field, description):
382-
assert schema["fields"][field]["description"] == description
389+
assert schema.field_map()[field].description == description
383390

384391
# This information is not in the schema yet,
385392
# https://github.com/sgkit-dev/bio2zarr/issues/123

0 commit comments

Comments
 (0)