From 77dfbc96f769e8b9a26168b28681b7382abd75df Mon Sep 17 00:00:00 2001 From: Ben Jeffery Date: Mon, 21 Apr 2025 01:31:03 +0100 Subject: [PATCH] Add defaults to schema --- bio2zarr/icf.py | 18 +++++-- bio2zarr/plink.py | 13 +++-- bio2zarr/vcz.py | 76 +++++++++++++-------------- tests/test_vcz.py | 131 ++++++++++++++++++++++++++++++++++------------ 4 files changed, 156 insertions(+), 82 deletions(-) diff --git a/bio2zarr/icf.py b/bio2zarr/icf.py index 887b8533..6be03e17 100644 --- a/bio2zarr/icf.py +++ b/bio2zarr/icf.py @@ -843,7 +843,7 @@ def convert_local_allele_field_types(fields): chunks = gt.chunks[:-1] dimensions = gt.dimensions[:-1] - la = vcz.ZarrArraySpec.new( + la = vcz.ZarrArraySpec( name="call_LA", dtype="i1", shape=gt.shape, @@ -1064,7 +1064,12 @@ def fixed_field_spec( dimensions=("variants",), chunks=None, ): - return vcz.ZarrArraySpec.new( + compressor = ( + vcz.DEFAULT_ZARR_COMPRESSOR_BOOL.get_config() + if dtype == "bool" + else None + ) + return vcz.ZarrArraySpec( source=source, name=name, dtype=dtype, @@ -1072,6 +1077,7 @@ def fixed_field_spec( description="", dimensions=dimensions, chunks=chunks or [schema_instance.variants_chunk_size], + compressor=compressor, ) alt_field = self.fields["ALT"] @@ -1135,7 +1141,7 @@ def fixed_field_spec( ] dimensions = ["variants", "samples"] array_specs.append( - vcz.ZarrArraySpec.new( + vcz.ZarrArraySpec( name="call_genotype_phased", dtype="bool", shape=list(shape), @@ -1148,23 +1154,25 @@ def fixed_field_spec( chunks += [ploidy] dimensions += ["ploidy"] array_specs.append( - vcz.ZarrArraySpec.new( + vcz.ZarrArraySpec( name="call_genotype", dtype=gt_field.smallest_dtype(), shape=list(shape), chunks=list(chunks), dimensions=list(dimensions), description="", + compressor=vcz.DEFAULT_ZARR_COMPRESSOR_GENOTYPES.get_config(), ) ) array_specs.append( - vcz.ZarrArraySpec.new( + vcz.ZarrArraySpec( name="call_genotype_mask", dtype="bool", shape=list(shape), chunks=list(chunks), dimensions=list(dimensions), description="", + compressor=vcz.DEFAULT_ZARR_COMPRESSOR_BOOL.get_config(), ) ) diff --git a/bio2zarr/plink.py b/bio2zarr/plink.py index 800aa820..cb333219 100644 --- a/bio2zarr/plink.py +++ b/bio2zarr/plink.py @@ -82,7 +82,7 @@ def generate_schema( ) array_specs = [ - vcz.ZarrArraySpec.new( + vcz.ZarrArraySpec( source="position", name="variant_position", dtype="i4", @@ -91,7 +91,7 @@ def generate_schema( chunks=[schema_instance.variants_chunk_size], description=None, ), - vcz.ZarrArraySpec.new( + vcz.ZarrArraySpec( name="variant_allele", dtype="O", shape=[m, 2], @@ -99,7 +99,7 @@ def generate_schema( chunks=[schema_instance.variants_chunk_size, 2], description=None, ), - vcz.ZarrArraySpec.new( + vcz.ZarrArraySpec( name="call_genotype_phased", dtype="bool", shape=[m, n], @@ -109,8 +109,9 @@ def generate_schema( schema_instance.samples_chunk_size, ], description=None, + compressor=vcz.DEFAULT_ZARR_COMPRESSOR_BOOL.get_config(), ), - vcz.ZarrArraySpec.new( + vcz.ZarrArraySpec( name="call_genotype", dtype="i1", shape=[m, n, 2], @@ -121,8 +122,9 @@ def generate_schema( 2, ], description=None, + compressor=vcz.DEFAULT_ZARR_COMPRESSOR_BOOL.get_config(), ), - vcz.ZarrArraySpec.new( + vcz.ZarrArraySpec( name="call_genotype_mask", dtype="bool", shape=[m, n, 2], @@ -133,6 +135,7 @@ def generate_schema( 2, ], description=None, + compressor=vcz.DEFAULT_ZARR_COMPRESSOR_BOOL.get_config(), ), ] schema_instance.fields = array_specs diff --git a/bio2zarr/vcz.py b/bio2zarr/vcz.py index 94022327..622db76c 100644 --- a/bio2zarr/vcz.py +++ b/bio2zarr/vcz.py @@ -16,6 +16,12 @@ ZARR_SCHEMA_FORMAT_VERSION = "0.5" DEFAULT_ZARR_COMPRESSOR = numcodecs.Blosc(cname="zstd", clevel=7) +DEFAULT_ZARR_COMPRESSOR_GENOTYPES = numcodecs.Blosc( + cname="zstd", clevel=7, shuffle=numcodecs.Blosc.BITSHUFFLE +) +DEFAULT_ZARR_COMPRESSOR_BOOL = numcodecs.Blosc( + cname="zstd", clevel=7, shuffle=numcodecs.Blosc.BITSHUFFLE +) _fixed_field_descriptions = { "variant_contig": "An identifier from the reference genome or an angle-bracketed ID" @@ -93,8 +99,8 @@ class ZarrArraySpec: chunks: tuple dimensions: tuple description: str - compressor: dict - filters: list + compressor: dict = None + filters: list = None source: str = None def __post_init__(self): @@ -105,15 +111,7 @@ def __post_init__(self): self.shape = tuple(self.shape) self.chunks = tuple(self.chunks) self.dimensions = tuple(self.dimensions) - self.filters = tuple(self.filters) - - @staticmethod - def new(**kwargs): - spec = ZarrArraySpec( - **kwargs, compressor=DEFAULT_ZARR_COMPRESSOR.get_config(), filters=[] - ) - spec._choose_compressor_settings() - return spec + self.filters = tuple(self.filters) if self.filters is not None else None @staticmethod def from_field( @@ -124,6 +122,8 @@ def from_field( variants_chunk_size, samples_chunk_size, array_name=None, + compressor=None, + filters=None, ): shape = [num_variants] prefix = "variant_" @@ -150,7 +150,7 @@ def from_field( dimensions.append("genotypes") else: dimensions.append(f"{vcf_field.category}_{vcf_field.name}_dim") - return ZarrArraySpec.new( + return ZarrArraySpec( source=vcf_field.full_name, name=array_name, dtype=vcf_field.smallest_dtype(), @@ -158,31 +158,10 @@ def from_field( chunks=chunks, dimensions=dimensions, description=vcf_field.description, + compressor=compressor, + filters=filters, ) - def _choose_compressor_settings(self): - """ - Choose compressor and filter settings based on the size and - type of the array, plus some hueristics from observed properties - of VCFs. - - See https://github.com/pystatgen/bio2zarr/discussions/74 - """ - # Default is to not shuffle, because autoshuffle isn't recognised - # by many Zarr implementations, and shuffling can lead to worse - # performance in some cases anyway. Turning on shuffle should be a - # deliberate choice. - shuffle = numcodecs.Blosc.NOSHUFFLE - if self.name == "call_genotype" and self.dtype == "i1": - # call_genotype gets BITSHUFFLE by default as it gets - # significantly better compression (at a cost of slower - # decoding) - shuffle = numcodecs.Blosc.BITSHUFFLE - elif self.dtype == "bool": - shuffle = numcodecs.Blosc.BITSHUFFLE - - self.compressor["shuffle"] = shuffle - @property def chunk_nbytes(self): """ @@ -240,6 +219,7 @@ class VcfZarrSchema(core.JsonDataclass): samples_chunk_size: int variants_chunk_size: int fields: list + defaults: dict def __init__( self, @@ -247,9 +227,16 @@ def __init__( fields: list, variants_chunk_size: int = None, samples_chunk_size: int = None, + defaults: dict = None, ): self.format_version = format_version self.fields = fields + defaults = defaults.copy() if defaults is not None else {} + if defaults.get("compressor", None) is None: + defaults["compressor"] = DEFAULT_ZARR_COMPRESSOR.get_config() + if defaults.get("filters", None) is None: + defaults["filters"] = [] + self.defaults = defaults if variants_chunk_size is None: variants_chunk_size = 1000 self.variants_chunk_size = variants_chunk_size @@ -533,7 +520,7 @@ def init( total_chunks = 0 for field in self.schema.fields: - a = self.init_array(root, field, partitions[-1].stop) + a = self.init_array(root, self.metadata.schema, field, partitions[-1].stop) total_chunks += a.nchunks logger.info("Writing WIP metadata") @@ -600,9 +587,20 @@ def encode_filters(self, root): ) array.attrs["_ARRAY_DIMENSIONS"] = ["filters"] - def init_array(self, root, array_spec, variants_dim_size): + def init_array(self, root, schema, array_spec, variants_dim_size): kwargs = dict(zarr_utils.ZARR_FORMAT_KWARGS) - filters = [numcodecs.get_codec(filt) for filt in array_spec.filters] + filters = ( + array_spec.filters + if array_spec.filters is not None + else schema.defaults["filters"] + ) + filters = [numcodecs.get_codec(filt) for filt in filters] + compressor = ( + array_spec.compressor + if array_spec.compressor is not None + else schema.defaults["compressor"] + ) + compressor = numcodecs.get_codec(compressor) if array_spec.dtype == "O": if zarr_utils.zarr_v3(): filters = [*list(filters), numcodecs.VLenUTF8()] @@ -620,7 +618,7 @@ def init_array(self, root, array_spec, variants_dim_size): shape=shape, chunks=array_spec.chunks, dtype=array_spec.dtype, - compressor=numcodecs.get_codec(array_spec.compressor), + compressor=compressor, filters=filters, **kwargs, ) diff --git a/tests/test_vcz.py b/tests/test_vcz.py index 1c30a5a9..e76196b5 100644 --- a/tests/test_vcz.py +++ b/tests/test_vcz.py @@ -226,10 +226,14 @@ def test_codec(self, tmp_path, icf_path, cname, clevel, shuffle): zarr_path = tmp_path / "zarr" icf = icf_mod.IntermediateColumnarFormat(icf_path) schema = icf.generate_schema() + field_changed = False for array_spec in schema.fields: - array_spec.compressor["cname"] = cname - array_spec.compressor["clevel"] = clevel - array_spec.compressor["shuffle"] = shuffle + if array_spec.compressor is not None: + array_spec.compressor["cname"] = cname + array_spec.compressor["clevel"] = clevel + array_spec.compressor["shuffle"] = shuffle + field_changed = True + assert field_changed schema_path = tmp_path / "schema" with open(schema_path, "w") as f: f.write(schema.asjson()) @@ -237,9 +241,10 @@ def test_codec(self, tmp_path, icf_path, cname, clevel, shuffle): root = zarr.open(zarr_path) for array_spec in schema.fields: a = root[array_spec.name] - assert a.compressor.cname == cname - assert a.compressor.clevel == clevel - assert a.compressor.shuffle == shuffle + if array_spec.compressor is not None: + assert a.compressor.cname == cname + assert a.compressor.clevel == clevel + assert a.compressor.shuffle == shuffle @pytest.mark.parametrize("dtype", ["i4", "i8"]) def test_genotype_dtype(self, tmp_path, icf_path, dtype): @@ -341,14 +346,8 @@ def test_variant_contig(self, schema): "description": "An identifier from the reference genome or an " "angle-bracketed ID string pointing to a contig in the assembly file", "source": None, - "compressor": { - "id": "blosc", - "cname": "zstd", - "clevel": 7, - "shuffle": 0, - "blocksize": 0, - }, - "filters": tuple(), + "compressor": None, + "filters": None, } def test_call_genotype(self, schema): @@ -367,7 +366,7 @@ def test_call_genotype(self, schema): "shuffle": 2, "blocksize": 0, }, - "filters": tuple(), + "filters": None, } def test_call_genotype_mask(self, schema): @@ -386,7 +385,7 @@ def test_call_genotype_mask(self, schema): "shuffle": 2, "blocksize": 0, }, - "filters": tuple(), + "filters": None, } def test_call_genotype_phased(self, schema): @@ -405,7 +404,7 @@ def test_call_genotype_phased(self, schema): "shuffle": 2, "blocksize": 0, }, - "filters": tuple(), + "filters": None, } def test_call_GQ(self, schema): @@ -417,14 +416,8 @@ def test_call_GQ(self, schema): "dimensions": ("variants", "samples"), "description": "Genotype Quality", "source": "FORMAT/GQ", - "compressor": { - "id": "blosc", - "cname": "zstd", - "clevel": 7, - "shuffle": 0, - "blocksize": 0, - }, - "filters": tuple(), + "compressor": None, + "filters": None, } @@ -447,14 +440,8 @@ def test_call_LA(self, local_alleles_schema): "0-based indices into REF+ALT, indicating which alleles" " are relevant (local) for the current sample" ), - "compressor": { - "id": "blosc", - "cname": "zstd", - "clevel": 7, - "shuffle": 0, - "blocksize": 0, - }, - "filters": tuple(), + "compressor": None, + "filters": None, } @@ -754,3 +741,81 @@ def test_no_such_path(self, bad_path): def test_unknown_format(self, path): with pytest.raises(ValueError, match="not in ICF or VCF Zarr format"): icf_mod.inspect(path) + + +class TestSchemaDefaults: + def test_default_compressor_and_filters(self, schema): + assert "compressor" in schema.defaults + assert schema.defaults["compressor"] == vcz.DEFAULT_ZARR_COMPRESSOR.get_config() + assert "filters" in schema.defaults + assert schema.defaults["filters"] == [] + + def test_custom_defaults(self, icf_path): + custom_defaults = { + "compressor": {"id": "blosc", "cname": "lz4", "clevel": 3, "shuffle": 1}, + "filters": [{"id": "delta", "dtype": "