diff --git a/bio2zarr/plink.py b/bio2zarr/plink.py
index 112f77aa..69ffc863 100644
--- a/bio2zarr/plink.py
+++ b/bio2zarr/plink.py
@@ -73,19 +73,14 @@ def generate_schema(
         n = self.bed.iid_count
         m = self.bed.sid_count
         logging.info(f"Scanned plink with {n} samples and {m} variants")
-
-        # Define dimensions with sizes and chunk sizes
-        dimensions = {
-            "variants": vcz.VcfZarrDimension(
-                size=m, chunk_size=variants_chunk_size or vcz.DEFAULT_VARIANT_CHUNK_SIZE
-            ),
-            "samples": vcz.VcfZarrDimension(
-                size=n, chunk_size=samples_chunk_size or vcz.DEFAULT_SAMPLE_CHUNK_SIZE
-            ),
-            "ploidy": vcz.VcfZarrDimension(size=2),
-            "alleles": vcz.VcfZarrDimension(size=2),
-        }
-
+        dimensions = vcz.standard_dimensions(
+            variants_size=m,
+            variants_chunk_size=variants_chunk_size,
+            samples_size=n,
+            samples_chunk_size=samples_chunk_size,
+            ploidy_size=2,
+            alleles_size=2,
+        )
         schema_instance = vcz.VcfZarrSchema(
             format_version=vcz.ZARR_SCHEMA_FORMAT_VERSION,
             dimensions=dimensions,
diff --git a/bio2zarr/tskit.py b/bio2zarr/tskit.py
index a5241389..2e442461 100644
--- a/bio2zarr/tskit.py
+++ b/bio2zarr/tskit.py
@@ -135,18 +135,14 @@ def generate_schema(
         logging.info(
             f"Maximum ploidy: {self.max_ploidy}, maximum alleles: {max_alleles}"
         )
-
-        dimensions = {
-            "variants": vcz.VcfZarrDimension(
-                size=m, chunk_size=variants_chunk_size or vcz.DEFAULT_VARIANT_CHUNK_SIZE
-            ),
-            "samples": vcz.VcfZarrDimension(
-                size=n, chunk_size=samples_chunk_size or vcz.DEFAULT_SAMPLE_CHUNK_SIZE
-            ),
-            "ploidy": vcz.VcfZarrDimension(size=self.max_ploidy),
-            "alleles": vcz.VcfZarrDimension(size=max_alleles),
-        }
-
+        dimensions = vcz.standard_dimensions(
+            variants_size=m,
+            variants_chunk_size=variants_chunk_size,
+            samples_size=n,
+            samples_chunk_size=samples_chunk_size,
+            ploidy_size=self.max_ploidy,
+            alleles_size=max_alleles,
+        )
         schema_instance = vcz.VcfZarrSchema(
             format_version=vcz.ZARR_SCHEMA_FORMAT_VERSION,
             dimensions=dimensions,
diff --git a/bio2zarr/vcf.py b/bio2zarr/vcf.py
index 252dbd27..0726dcdc 100644
--- a/bio2zarr/vcf.py
+++ b/bio2zarr/vcf.py
@@ -858,8 +858,8 @@ def convert_local_allele_field_types(fields, schema_instance):
             " are relevant (local) for the current sample"
         ),
     )
-    schema_instance.dimensions["local_alleles"] = vcz.VcfZarrDimension(
-        size=schema_instance.dimensions["ploidy"].size
+    schema_instance.dimensions["local_alleles"] = vcz.VcfZarrDimension.unchunked(
+        schema_instance.dimensions["ploidy"].size
     )
 
     ad = fields_by_name.get("call_AD", None)
@@ -869,7 +869,9 @@
         ad.source = None
         ad.dimensions = (*dimensions, "local_alleles_AD")
         ad.description += " (local-alleles)"
-        schema_instance.dimensions["local_alleles_AD"] = vcz.VcfZarrDimension(size=2)
+        schema_instance.dimensions["local_alleles_AD"] = vcz.VcfZarrDimension.unchunked(
+            2
+        )
 
     pl = fields_by_name.get("call_PL", None)
     if pl is not None:
@@ -879,7 +881,7 @@
         pl.description += " (local-alleles)"
         pl.dimensions = (*dimensions, "local_" + pl.dimensions[-1].split("_")[-1])
         schema_instance.dimensions["local_" + pl.dimensions[-1].split("_")[-1]] = (
-            vcz.VcfZarrDimension(size=3)
+            vcz.VcfZarrDimension.unchunked(3)
         )
 
     return [*fields, la]
@@ -1057,40 +1059,41 @@ def iter_alleles_and_genotypes(self, start, stop, shape, num_alleles):
     def generate_schema(
         self, variants_chunk_size=None, samples_chunk_size=None, local_alleles=None
     ):
-        m = self.num_records
-        n = self.num_samples
         if local_alleles is None:
             local_alleles = False
         max_alleles = max(self.fields["ALT"].vcf_field.summary.max_number + 1, 2)
-        dimensions = {
-            "variants": vcz.VcfZarrDimension(
-                size=m, chunk_size=variants_chunk_size or vcz.DEFAULT_VARIANT_CHUNK_SIZE
-            ),
-            "samples": vcz.VcfZarrDimension(
-                size=n, chunk_size=samples_chunk_size or vcz.DEFAULT_SAMPLE_CHUNK_SIZE
-            ),
-            # ploidy and genotypes added conditionally below
-            "alleles": vcz.VcfZarrDimension(size=max_alleles),
-            "alt_alleles": vcz.VcfZarrDimension(size=max_alleles - 1),
-            "filters": vcz.VcfZarrDimension(size=self.metadata.num_filters),
-        }

         # Add ploidy and genotypes dimensions only when needed
         max_genotypes = 0
         for field in self.metadata.format_fields:
             if field.vcf_number == "G":
                 max_genotypes = max(max_genotypes, field.summary.max_number)
+
+        ploidy = None
+        genotypes_size = None
         if self.gt_field is not None:
             ploidy = max(self.gt_field.summary.max_number - 1, 1)
-            dimensions["ploidy"] = vcz.VcfZarrDimension(size=ploidy)
-            max_genotypes = math.comb(max_alleles + ploidy - 1, ploidy)
-            dimensions["genotypes"] = vcz.VcfZarrDimension(size=max_genotypes)
+            # NOTE: it's not clear why we're computing this, when we must have had
+            # at least one number=G field to require it anyway?
+            genotypes_size = math.comb(max_alleles + ploidy - 1, ploidy)
+            # assert max_genotypes == genotypes_size
         else:
             if max_genotypes > 0:
                 # there is no GT field, but there is at least one Number=G field,
                 # so need to define genotypes dimension
-                dimensions["genotypes"] = vcz.VcfZarrDimension(size=max_genotypes)
+                genotypes_size = max_genotypes
+
+        dimensions = vcz.standard_dimensions(
+            variants_size=self.num_records,
+            variants_chunk_size=variants_chunk_size,
+            samples_size=self.num_samples,
+            samples_chunk_size=samples_chunk_size,
+            alleles_size=max_alleles,
+            filters_size=self.metadata.num_filters,
+            ploidy_size=ploidy,
+            genotypes_size=genotypes_size,
+        )

         schema_instance = vcz.VcfZarrSchema(
             format_version=vcz.ZARR_SCHEMA_FORMAT_VERSION,
             dimensions=dimensions,
@@ -1173,7 +1176,7 @@ def fixed_field_spec(name, dtype, source=None, dimensions=("variants",)):
                 continue
             array_specs.append(spec_from_field(field))

-        if self.gt_field is not None and n > 0:
+        if self.gt_field is not None and self.num_samples > 0:
             array_specs.append(
                 vcz.ZarrArraySpec(
                     name="call_genotype_phased",
diff --git a/bio2zarr/vcz.py b/bio2zarr/vcz.py
index 634947ca..14d3ba19 100644
--- a/bio2zarr/vcz.py
+++ b/bio2zarr/vcz.py
@@ -102,24 +102,60 @@ def generate_schema(self, variants_chunk_size, samples_chunk_size, local_alleles
 @dataclasses.dataclass
 class VcfZarrDimension:
     size: int
-    chunk_size: int = None
-
-    def __post_init__(self):
-        if self.chunk_size is None:
-            self.chunk_size = self.size
+    chunk_size: int

     def asdict(self):
-        result = {"size": self.size}
-        if self.chunk_size != self.size:
-            result["chunk_size"] = self.chunk_size
-        return result
+        return dataclasses.asdict(self)

     @classmethod
     def fromdict(cls, d):
-        return cls(
-            size=d["size"],
-            chunk_size=d.get("chunk_size", d["size"]),
-        )
+        return cls(**d)
+
+    @classmethod
+    def unchunked(cls, size):
+        return cls(size, max(size, 1))
+
+
+def standard_dimensions(
+    *,
+    variants_size,
+    samples_size,
+    variants_chunk_size=None,
+    samples_chunk_size=None,
+    alleles_size=None,
+    filters_size=None,
+    ploidy_size=None,
+    genotypes_size=None,
+):
+    """
+    Returns a dictionary mapping dimension names to definition for the standard
+    fields in a VCF.
+    """
+    if variants_chunk_size is None:
+        variants_chunk_size = max(1, min(variants_size, DEFAULT_VARIANT_CHUNK_SIZE))
+    if samples_chunk_size is None:
+        samples_chunk_size = max(1, min(samples_size, DEFAULT_SAMPLE_CHUNK_SIZE))
+
+    dimensions = {
+        "variants": VcfZarrDimension(variants_size, variants_chunk_size),
+        "samples": VcfZarrDimension(samples_size, samples_chunk_size),
+    }
+
+    if alleles_size is not None:
+        dimensions["alleles"] = VcfZarrDimension.unchunked(alleles_size)
+        if alleles_size > 1:
+            dimensions["alt_alleles"] = VcfZarrDimension.unchunked(alleles_size - 1)
+
+    if filters_size is not None:
+        dimensions["filters"] = VcfZarrDimension.unchunked(filters_size)
+
+    if ploidy_size is not None:
+        dimensions["ploidy"] = VcfZarrDimension.unchunked(ploidy_size)
+
+    if genotypes_size is not None:
+        dimensions["genotypes"] = VcfZarrDimension.unchunked(genotypes_size)
+
+    return dimensions


 @dataclasses.dataclass
@@ -210,8 +246,8 @@ def from_field(
         elif max_number > 1 or vcf_field.full_name == "FORMAT/LAA":
             dimensions.append(f"{vcf_field.category}_{vcf_field.name}_dim")
             if dimensions[-1] not in schema.dimensions:
-                schema.dimensions[dimensions[-1]] = VcfZarrDimension(
-                    size=vcf_field.max_number
+                schema.dimensions[dimensions[-1]] = VcfZarrDimension.unchunked(
+                    vcf_field.max_number
                 )

         return ZarrArraySpec(
@@ -284,7 +320,7 @@ def __init__(
         self,
         format_version: str,
         fields: list,
-        dimensions: dict = None,
+        dimensions: dict,
         defaults: dict = None,
     ):
         self.format_version = format_version
@@ -295,15 +331,6 @@ def __init__(
         if defaults.get("filters", None) is None:
             defaults["filters"] = []
         self.defaults = defaults
-        if dimensions is None:
-            dimensions = {
-                "variants": VcfZarrDimension(
-                    size=0, chunk_size=DEFAULT_VARIANT_CHUNK_SIZE
-                ),
-                "samples": VcfZarrDimension(
-                    size=0, chunk_size=DEFAULT_SAMPLE_CHUNK_SIZE
-                ),
-            }
         self.dimensions = dimensions

     def get_shape(self, dimensions):
@@ -349,7 +376,9 @@ def fromdict(d):

         ret = VcfZarrSchema(**d)
         ret.fields = [ZarrArraySpec(**sd) for sd in d["fields"]]
-        ret.dimensions = {k: VcfZarrDimension(**v) for k, v in d["dimensions"].items()}
+        ret.dimensions = {
+            k: VcfZarrDimension.fromdict(v) for k, v in d["dimensions"].items()
+        }
         return ret


@@ -1117,7 +1146,6 @@ def __init__(self, path):
     def create_index(self):
         """Create an index to support efficient region queries."""
         root = zarr.open_group(store=self.path, mode="r+")
-        print(list(root.keys()))
         if (
             "variant_contig" not in root
             or "variant_position" not in root
diff --git a/tests/test_vcz.py b/tests/test_vcz.py
index 7a5d8c31..4f60e6ed 100644
--- a/tests/test_vcz.py
+++ b/tests/test_vcz.py
@@ -187,8 +187,8 @@ def test_chunk_sizes(self, icf_path, samples_chunk_size, variants_chunk_size):
     def test_default_chunk_size(self, icf_path):
         icf = vcf_mod.IntermediateColumnarFormat(icf_path)
         schema = icf.generate_schema()
-        assert schema.dimensions["samples"].chunk_size == 10_000
-        assert schema.dimensions["variants"].chunk_size == 1000
+        assert schema.dimensions["samples"].chunk_size == 3
+        assert schema.dimensions["variants"].chunk_size == 9


 class TestSchemaJsonRoundTrip:
@@ -329,7 +329,6 @@ def test_chunk_not_too_large(self, schema, size):
         field = schema.field_map()["variant_H2"]
         schema.dimensions[field.dimensions[-1]].size = size
         schema.dimensions[field.dimensions[-1]].chunk_size = size
-        print(schema.dimensions)
         schema.validate()


@@ -338,8 +337,8 @@ def test_format_version(self, schema):
         assert schema.format_version == vcz.ZARR_SCHEMA_FORMAT_VERSION

     def test_chunk_size(self, schema):
-        assert schema.dimensions["samples"].chunk_size == 10000
-        assert schema.dimensions["variants"].chunk_size == 1000
+        assert schema.dimensions["samples"].chunk_size == 3
+        assert schema.dimensions["variants"].chunk_size == 9

     def test_variant_contig(self, schema):
         assert get_field_dict(schema, "variant_contig") == {
@@ -752,6 +751,7 @@ def test_custom_defaults(self, icf_path):
         schema = vcz.VcfZarrSchema(
             format_version=vcz.ZARR_SCHEMA_FORMAT_VERSION,
             fields=[],
+            dimensions={},
             defaults=custom_defaults,
         )

@@ -762,6 +762,7 @@ def test_partial_defaults(self, icf_path):
         schema1 = vcz.VcfZarrSchema(
             format_version=vcz.ZARR_SCHEMA_FORMAT_VERSION,
             fields=[],
+            dimensions={},
             defaults={"compressor": {"id": "blosc", "cname": "zlib", "clevel": 5}},
         )
         assert schema1.defaults["compressor"] == {
@@ -775,6 +776,7 @@
         schema2 = vcz.VcfZarrSchema(
             format_version=vcz.ZARR_SCHEMA_FORMAT_VERSION,
             fields=[],
+            dimensions={},
             defaults={"filters": [{"id": "delta"}]},
         )
         assert (
@@ -820,27 +822,21 @@ def test_dimension_initialization(self):
         assert dim1.size == 100
         assert dim1.chunk_size == 20

-        # Test with only size (chunk_size should default to size)
-        dim2 = vcz.VcfZarrDimension(size=50)
-        assert dim2.size == 50
-        assert dim2.chunk_size == 50
+    def test_unchunked(self):
+        dim = vcz.VcfZarrDimension.unchunked(50)
+        assert dim.size == 50
+        assert dim.chunk_size == 50

-    def test_asdict(self):
-        # When chunk_size equals size, it shouldn't be included in dict
-        dim1 = vcz.VcfZarrDimension(size=100, chunk_size=100)
-        assert dim1.asdict() == {"size": 100}
+    def test_unchunked_zero_size(self):
+        dim = vcz.VcfZarrDimension.unchunked(0)
+        assert dim.size == 0
+        assert dim.chunk_size == 1

-        # When chunk_size differs from size, it should be included in dict
-        dim2 = vcz.VcfZarrDimension(size=100, chunk_size=20)
-        assert dim2.asdict() == {"size": 100, "chunk_size": 20}
+    def test_asdict(self):
+        dim1 = vcz.VcfZarrDimension(size=100, chunk_size=101)
+        assert dim1.asdict() == {"size": 100, "chunk_size": 101}

     def test_fromdict(self):
-        # With only size
-        dim1 = vcz.VcfZarrDimension.fromdict({"size": 75})
-        assert dim1.size == 75
-        assert dim1.chunk_size == 75
-
-        # With both size and chunk_size
         dim2 = vcz.VcfZarrDimension.fromdict({"size": 75, "chunk_size": 25})
         assert dim2.size == 75
         assert dim2.chunk_size == 25
@@ -899,6 +895,98 @@ def test_max_number_exceeds_dimension_size(
             vcz.ZarrArraySpec.from_field(vcf_field, schema)


+class TestStandardDimensions:
+    @pytest.mark.parametrize(
+        ("size", "chunk_size", "expected_chunk_size"),
+        [
+            (0, None, 1),
+            (0, 100, 100),
+            (1, 1, 1),
+            (1, None, 1),
+            (1, 10, 10),
+            (1_001, None, 1_000),
+            (10**9, None, 1_000),
+            (999, None, 999),
+            (1, 100_000, 100_000),
+        ],
+    )
+    def test_variants(self, size, chunk_size, expected_chunk_size):
+        dims = vcz.standard_dimensions(
+            variants_size=size, variants_chunk_size=chunk_size, samples_size=0
+        )
+        assert dims["variants"] == vcz.VcfZarrDimension(size, expected_chunk_size)
+
+    @pytest.mark.parametrize(
+        ("size", "chunk_size", "expected_chunk_size"),
+        [
+            (0, None, 1),
+            (0, 100, 100),
+            (1, 1, 1),
+            (1, None, 1),
+            (1, 10, 10),
+            (10_001, None, 10_000),
+            (10**9, None, 10_000),
+            (9_999, None, 9_999),
+            (1, 100_000, 100_000),
+        ],
+    )
+    def test_samples(self, size, chunk_size, expected_chunk_size):
+        dims = vcz.standard_dimensions(
+            variants_size=0, samples_size=size, samples_chunk_size=chunk_size
+        )
+        assert dims["samples"] == vcz.VcfZarrDimension(size, expected_chunk_size)
+
+    @pytest.mark.parametrize(
+        ("kwargs", "expected"),
+        [
+            (
+                {"variants_size": 1, "samples_size": 1, "alleles_size": 2},
+                {
+                    "variants": {"size": 1, "chunk_size": 1},
+                    "samples": {"size": 1, "chunk_size": 1},
+                    "alleles": {"size": 2, "chunk_size": 2},
+                    "alt_alleles": {"size": 1, "chunk_size": 1},
+                },
+            ),
+            (
+                {"variants_size": 0, "samples_size": 1, "alleles_size": 1},
+                {
+                    "variants": {"size": 0, "chunk_size": 1},
+                    "samples": {"size": 1, "chunk_size": 1},
+                    "alleles": {"size": 1, "chunk_size": 1},
+                },
+            ),
+            (
+                {"variants_size": 0, "samples_size": 1, "alleles_size": 0},
+                {
+                    "variants": {"size": 0, "chunk_size": 1},
+                    "samples": {"size": 1, "chunk_size": 1},
+                    "alleles": {"size": 0, "chunk_size": 1},
+                },
+            ),
+            (
+                {"variants_size": 0, "samples_size": 1, "filters_size": 2},
+                {
+                    "variants": {"size": 0, "chunk_size": 1},
+                    "samples": {"size": 1, "chunk_size": 1},
+                    "filters": {"size": 2, "chunk_size": 2},
+                },
+            ),
+        ],
+    )
+    def test_examples(self, kwargs, expected):
+        dims = {k: v.asdict() for k, v in vcz.standard_dimensions(**kwargs).items()}
+        assert dims == expected
+
+    @pytest.mark.parametrize("field", ["ploidy", "genotypes"])
+    @pytest.mark.parametrize("size", [0, 1, 2])
+    def test_simple_fields(self, field, size):
+        dims = vcz.standard_dimensions(
+            samples_size=1, variants_size=1, **{f"{field}_size": size}
+        )
+        assert dims[field].asdict() == {"size": size, "chunk_size": max(1, size)}
+
+
 def test_create_index_errors(tmp_path):
     root = zarr.open(tmp_path)
     root["foobar"] = np.array([1, 2, 3])
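Reviewer note (not part of the patch): below is a minimal sketch of how the new `vcz.standard_dimensions` helper and `VcfZarrDimension.unchunked` behave, based only on the hunks and tests above. The `from bio2zarr import vcz` import path and the 1000-variant / 10,000-sample default chunk sizes are assumptions inferred from the test cases, not verified against the installed package.

```python
import math

from bio2zarr import vcz  # assumed import path; vcz.py sits under bio2zarr/ in this diff

# standard_dimensions() builds the dimension dict that plink.py, tskit.py and
# vcf.py previously assembled by hand.
dims = vcz.standard_dimensions(
    variants_size=1_001,
    samples_size=3,
    alleles_size=2,
    ploidy_size=2,
)

# Default chunk sizes are min(size, DEFAULT_*_CHUNK_SIZE), clamped to at least 1,
# matching the (1_001, None, 1_000) and (1, None, 1) cases in TestStandardDimensions.
assert dims["variants"] == vcz.VcfZarrDimension(1_001, 1_000)
assert dims["samples"] == vcz.VcfZarrDimension(3, 3)

# Fixed-size dimensions are "unchunked": chunk_size == max(size, 1).
assert dims["ploidy"].asdict() == {"size": 2, "chunk_size": 2}
assert vcz.VcfZarrDimension.unchunked(0).chunk_size == 1

# "alt_alleles" only appears when alleles_size > 1, with size alleles_size - 1.
assert dims["alt_alleles"] == vcz.VcfZarrDimension.unchunked(1)

# The genotypes dimension in vcf.py uses the multiset count
# C(max_alleles + ploidy - 1, ploidy): for a biallelic diploid site this is
# C(3, 2) == 3, which is also the fixed local_PL dimension size used above.
assert math.comb(2 + 2 - 1, 2) == 3
```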