diff --git a/CHANGELOG.md b/CHANGELOG.md index e3d08519..c3f0feb1 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -5,6 +5,9 @@ Breaking changes - Remove explicit sample, contig and filter lists from the schema. Existing ICFs will need to be recreated. (#343) +- Add dimensions and default compressor and filter settings to the schema. + (#361) + # 0.1.5 2025-03-31 - Add support for merging contig IDs across multiple VCFs (#335) diff --git a/bio2zarr/icf.py b/bio2zarr/icf.py index eee6afca..3ac7425a 100644 --- a/bio2zarr/icf.py +++ b/bio2zarr/icf.py @@ -829,7 +829,7 @@ def __exit__(self, exc_type, exc_val, exc_tb): return False -def convert_local_allele_field_types(fields): +def convert_local_allele_field_types(fields, schema_instance): """ Update the specified list of fields to include the LAA field, and to convert any supported localisable fields to the L* counterpart. @@ -842,45 +842,45 @@ def convert_local_allele_field_types(fields): """ fields_by_name = {field.name: field for field in fields} gt = fields_by_name["call_genotype"] - if gt.shape[-1] != 2: - raise ValueError("Local alleles only supported on diploid data") - # TODO check if LA is already in here + if schema_instance.get_shape(["ploidy"])[0] != 2: + raise ValueError("Local alleles only supported on diploid data") - shape = gt.shape[:-1] - chunks = gt.chunks[:-1] dimensions = gt.dimensions[:-1] la = vcz.ZarrArraySpec( name="call_LA", dtype="i1", - shape=gt.shape, - chunks=gt.chunks, dimensions=(*dimensions, "local_alleles"), description=( "0-based indices into REF+ALT, indicating which alleles" " are relevant (local) for the current sample" ), ) + schema_instance.dimensions["local_alleles"] = vcz.VcfZarrDimension( + size=schema_instance.dimensions["ploidy"].size + ) + ad = fields_by_name.get("call_AD", None) if ad is not None: # TODO check if call_LAD is in the list already ad.name = "call_LAD" ad.source = None - ad.shape = (*shape, 2) - ad.chunks = (*chunks, 2) - ad.dimensions = (*dimensions, "local_alleles") + ad.dimensions = (*dimensions, "local_alleles_AD") ad.description += " (local-alleles)" + schema_instance.dimensions["local_alleles_AD"] = vcz.VcfZarrDimension(size=2) pl = fields_by_name.get("call_PL", None) if pl is not None: # TODO check if call_LPL is in the list already pl.name = "call_LPL" pl.source = None - pl.shape = (*shape, 3) - pl.chunks = (*chunks, 3) pl.description += " (local-alleles)" - pl.dimensions = (*dimensions, "local_" + pl.dimensions[-1]) + pl.dimensions = (*dimensions, "local_" + pl.dimensions[-1].split("_")[-1]) + schema_instance.dimensions["local_" + pl.dimensions[-1].split("_")[-1]] = ( + vcz.VcfZarrDimension(size=3) + ) + return [*fields, la] @@ -1042,36 +1042,40 @@ def generate_schema( if local_alleles is None: local_alleles = False + dimensions = { + "variants": vcz.VcfZarrDimension( + size=m, chunk_size=variants_chunk_size or vcz.DEFAULT_VARIANT_CHUNK_SIZE + ), + "samples": vcz.VcfZarrDimension( + size=n, chunk_size=samples_chunk_size or vcz.DEFAULT_SAMPLE_CHUNK_SIZE + ), + # ploidy added conditionally below + "alleles": vcz.VcfZarrDimension( + size=max(self.fields["ALT"].vcf_field.summary.max_number + 1, 2) + ), + "filters": vcz.VcfZarrDimension(size=self.metadata.num_filters), + } + schema_instance = vcz.VcfZarrSchema( format_version=vcz.ZARR_SCHEMA_FORMAT_VERSION, - samples_chunk_size=samples_chunk_size, - variants_chunk_size=variants_chunk_size, + dimensions=dimensions, fields=[], ) logger.info( "Generating schema with chunks=" - f"{schema_instance.variants_chunk_size, schema_instance.samples_chunk_size}" + f"variants={dimensions['variants'].chunk_size}, " + f"samples={dimensions['samples'].chunk_size}" ) def spec_from_field(field, array_name=None): return vcz.ZarrArraySpec.from_field( field, - num_samples=n, - num_variants=m, - samples_chunk_size=schema_instance.samples_chunk_size, - variants_chunk_size=schema_instance.variants_chunk_size, + schema_instance, array_name=array_name, ) - def fixed_field_spec( - name, - dtype, - source=None, - shape=(m,), - dimensions=("variants",), - chunks=None, - ): + def fixed_field_spec(name, dtype, source=None, dimensions=("variants",)): compressor = ( vcz.DEFAULT_ZARR_COMPRESSOR_BOOL.get_config() if dtype == "bool" @@ -1081,16 +1085,11 @@ def fixed_field_spec( source=source, name=name, dtype=dtype, - shape=shape, description="", dimensions=dimensions, - chunks=chunks or [schema_instance.variants_chunk_size], compressor=compressor, ) - alt_field = self.fields["ALT"] - max_alleles = alt_field.vcf_field.summary.max_number + 1 - array_specs = [ fixed_field_spec( name="variant_contig", @@ -1099,16 +1098,12 @@ def fixed_field_spec( fixed_field_spec( name="variant_filter", dtype="bool", - shape=(m, self.metadata.num_filters), dimensions=["variants", "filters"], - chunks=(schema_instance.variants_chunk_size, self.metadata.num_filters), ), fixed_field_spec( name="variant_allele", dtype="O", - shape=(m, max_alleles), dimensions=["variants", "alleles"], - chunks=(schema_instance.variants_chunk_size, max_alleles), ), fixed_field_spec( name="variant_id", @@ -1142,32 +1137,23 @@ def fixed_field_spec( if gt_field is not None and n > 0: ploidy = max(gt_field.summary.max_number - 1, 1) - shape = [m, n] - chunks = [ - schema_instance.variants_chunk_size, - schema_instance.samples_chunk_size, - ] - dimensions = ["variants", "samples"] + # Add ploidy dimension only when needed + schema_instance.dimensions["ploidy"] = vcz.VcfZarrDimension(size=ploidy) + array_specs.append( vcz.ZarrArraySpec( name="call_genotype_phased", dtype="bool", - shape=list(shape), - chunks=list(chunks), - dimensions=list(dimensions), + dimensions=["variants", "samples"], description="", + compressor=vcz.DEFAULT_ZARR_COMPRESSOR_BOOL.get_config(), ) ) - shape += [ploidy] - chunks += [ploidy] - dimensions += ["ploidy"] array_specs.append( vcz.ZarrArraySpec( name="call_genotype", dtype=gt_field.smallest_dtype(), - shape=list(shape), - chunks=list(chunks), - dimensions=list(dimensions), + dimensions=["variants", "samples", "ploidy"], description="", compressor=vcz.DEFAULT_ZARR_COMPRESSOR_GENOTYPES.get_config(), ) @@ -1176,16 +1162,14 @@ def fixed_field_spec( vcz.ZarrArraySpec( name="call_genotype_mask", dtype="bool", - shape=list(shape), - chunks=list(chunks), - dimensions=list(dimensions), + dimensions=["variants", "samples", "ploidy"], description="", compressor=vcz.DEFAULT_ZARR_COMPRESSOR_BOOL.get_config(), ) ) if local_alleles: - array_specs = convert_local_allele_field_types(array_specs) + array_specs = convert_local_allele_field_types(array_specs, schema_instance) schema_instance.fields = array_specs return schema_instance diff --git a/bio2zarr/plink.py b/bio2zarr/plink.py index cb333219..7c672696 100644 --- a/bio2zarr/plink.py +++ b/bio2zarr/plink.py @@ -69,16 +69,28 @@ def generate_schema( m = self.bed.sid_count logging.info(f"Scanned plink with {n} samples and {m} variants") + # Define dimensions with sizes and chunk sizes + dimensions = { + "variants": vcz.VcfZarrDimension( + size=m, chunk_size=variants_chunk_size or vcz.DEFAULT_VARIANT_CHUNK_SIZE + ), + "samples": vcz.VcfZarrDimension( + size=n, chunk_size=samples_chunk_size or vcz.DEFAULT_SAMPLE_CHUNK_SIZE + ), + "ploidy": vcz.VcfZarrDimension(size=2), + "alleles": vcz.VcfZarrDimension(size=2), + } + schema_instance = vcz.VcfZarrSchema( format_version=vcz.ZARR_SCHEMA_FORMAT_VERSION, - samples_chunk_size=samples_chunk_size, - variants_chunk_size=variants_chunk_size, + dimensions=dimensions, fields=[], ) logger.info( "Generating schema with chunks=" - f"{schema_instance.variants_chunk_size, schema_instance.samples_chunk_size}" + f"variants={dimensions['variants'].chunk_size}, " + f"samples={dimensions['samples'].chunk_size}" ) array_specs = [ @@ -86,54 +98,33 @@ def generate_schema( source="position", name="variant_position", dtype="i4", - shape=[m], dimensions=["variants"], - chunks=[schema_instance.variants_chunk_size], description=None, ), vcz.ZarrArraySpec( name="variant_allele", dtype="O", - shape=[m, 2], dimensions=["variants", "alleles"], - chunks=[schema_instance.variants_chunk_size, 2], description=None, ), vcz.ZarrArraySpec( name="call_genotype_phased", dtype="bool", - shape=[m, n], dimensions=["variants", "samples"], - chunks=[ - schema_instance.variants_chunk_size, - schema_instance.samples_chunk_size, - ], description=None, compressor=vcz.DEFAULT_ZARR_COMPRESSOR_BOOL.get_config(), ), vcz.ZarrArraySpec( name="call_genotype", dtype="i1", - shape=[m, n, 2], dimensions=["variants", "samples", "ploidy"], - chunks=[ - schema_instance.variants_chunk_size, - schema_instance.samples_chunk_size, - 2, - ], description=None, compressor=vcz.DEFAULT_ZARR_COMPRESSOR_BOOL.get_config(), ), vcz.ZarrArraySpec( name="call_genotype_mask", dtype="bool", - shape=[m, n, 2], dimensions=["variants", "samples", "ploidy"], - chunks=[ - schema_instance.variants_chunk_size, - schema_instance.samples_chunk_size, - 2, - ], description=None, compressor=vcz.DEFAULT_ZARR_COMPRESSOR_BOOL.get_config(), ), diff --git a/bio2zarr/vcz.py b/bio2zarr/vcz.py index c7bc6f1b..90022e6a 100644 --- a/bio2zarr/vcz.py +++ b/bio2zarr/vcz.py @@ -15,6 +15,8 @@ logger = logging.getLogger(__name__) ZARR_SCHEMA_FORMAT_VERSION = "0.5" +DEFAULT_VARIANT_CHUNK_SIZE = 1000 +DEFAULT_SAMPLE_CHUNK_SIZE = 10_000 DEFAULT_ZARR_COMPRESSOR = numcodecs.Blosc(cname="zstd", clevel=7) DEFAULT_ZARR_COMPRESSOR_GENOTYPES = numcodecs.Blosc( cname="zstd", clevel=7, shuffle=numcodecs.Blosc.BITSHUFFLE @@ -91,12 +93,33 @@ def generate_schema(self, variants_chunk_size, samples_chunk_size, local_alleles pass +@dataclasses.dataclass +class VcfZarrDimension: + size: int + chunk_size: int = None + + def __post_init__(self): + if self.chunk_size is None: + self.chunk_size = self.size + + def asdict(self): + result = {"size": self.size} + if self.chunk_size != self.size: + result["chunk_size"] = self.chunk_size + return result + + @classmethod + def fromdict(cls, d): + return cls( + size=d["size"], + chunk_size=d.get("chunk_size", d["size"]), + ) + + @dataclasses.dataclass class ZarrArraySpec: name: str dtype: str - shape: tuple - chunks: tuple dimensions: tuple description: str compressor: dict = None @@ -107,43 +130,53 @@ def __post_init__(self): if self.name in _fixed_field_descriptions: self.description = self.description or _fixed_field_descriptions[self.name] - # Ensure these are tuples for ease of comparison and consistency - self.shape = tuple(self.shape) - self.chunks = tuple(self.chunks) self.dimensions = tuple(self.dimensions) self.filters = tuple(self.filters) if self.filters is not None else None + def get_shape(self, schema): + return schema.get_shape(self.dimensions) + + def get_chunks(self, schema): + return schema.get_chunks(self.dimensions) + + def get_chunk_nbytes(self, schema): + element_size = np.dtype(self.dtype).itemsize + chunks = self.get_chunks(schema) + shape = self.get_shape(schema) + + # Calculate actual chunk size accounting for dimension limits + items = 1 + for i, chunk_size in enumerate(chunks): + items *= min(chunk_size, shape[i]) + + # Include sizes for extra dimensions (if any) + if len(shape) > len(chunks): + for size in shape[len(chunks) :]: + items *= size + + return element_size * items + @staticmethod def from_field( vcf_field, + schema, *, - num_variants, - num_samples, - variants_chunk_size, - samples_chunk_size, array_name=None, compressor=None, filters=None, ): - shape = [num_variants] prefix = "variant_" dimensions = ["variants"] - chunks = [variants_chunk_size] if vcf_field.category == "FORMAT": prefix = "call_" - shape.append(num_samples) - chunks.append(samples_chunk_size) dimensions.append("samples") if array_name is None: array_name = prefix + vcf_field.name - # TODO make an option to add in the empty extra dimension max_number = vcf_field.max_number if (max_number > 0 and vcf_field.vcf_number in ("R", "A", "G")) or ( max_number > 1 or vcf_field.full_name == "FORMAT/LAA" ): - shape.append(max_number) - chunks.append(max_number) # TODO we should really be checking this to see if the named dimensions # are actually correct. if vcf_field.vcf_number == "R": @@ -154,42 +187,43 @@ def from_field( dimensions.append("genotypes") else: dimensions.append(f"{vcf_field.category}_{vcf_field.name}_dim") + if dimensions[-1] not in schema.dimensions: + schema.dimensions[dimensions[-1]] = VcfZarrDimension( + size=vcf_field.max_number + ) + return ZarrArraySpec( source=vcf_field.full_name, name=array_name, dtype=vcf_field.smallest_dtype(), - shape=shape, - chunks=chunks, dimensions=dimensions, description=vcf_field.description, compressor=compressor, filters=filters, ) - @property - def chunk_nbytes(self): + def chunk_nbytes(self, schema): """ Returns the nbytes for a single chunk in this array. """ items = 1 dim = 0 - for chunk_size in self.chunks: - size = min(chunk_size, self.shape[dim]) + for chunk_size in self.get_chunks(schema): + size = min(chunk_size, self.get_shape(schema)[dim]) items *= size dim += 1 # Include sizes for extra dimensions. - for size in self.shape[dim:]: + for size in self.get_shape(schema)[dim:]: items *= size dt = np.dtype(self.dtype) return items * dt.itemsize - @property - def variant_chunk_nbytes(self): + def variant_chunk_nbytes(self, schema): """ Returns the nbytes for a single variant chunk of this array. """ - chunk_items = self.chunks[0] - for size in self.shape[1:]: + chunk_items = self.get_chunks(schema)[0] + for size in self.get_shape(schema)[1:]: chunk_items *= size dt = np.dtype(self.dtype) if dt.kind == "O" and "samples" in self.dimensions: @@ -220,8 +254,7 @@ class Filter: @dataclasses.dataclass class VcfZarrSchema(core.JsonDataclass): format_version: str - samples_chunk_size: int - variants_chunk_size: int + dimensions: dict fields: list defaults: dict @@ -229,8 +262,7 @@ def __init__( self, format_version: str, fields: list, - variants_chunk_size: int = None, - samples_chunk_size: int = None, + dimensions: dict = None, defaults: dict = None, ): self.format_version = format_version @@ -241,25 +273,41 @@ def __init__( if defaults.get("filters", None) is None: defaults["filters"] = [] self.defaults = defaults - if variants_chunk_size is None: - variants_chunk_size = 1000 - self.variants_chunk_size = variants_chunk_size - if samples_chunk_size is None: - samples_chunk_size = 10_000 - self.samples_chunk_size = samples_chunk_size + if dimensions is None: + dimensions = { + "variants": VcfZarrDimension( + size=0, chunk_size=DEFAULT_VARIANT_CHUNK_SIZE + ), + "samples": VcfZarrDimension( + size=0, chunk_size=DEFAULT_SAMPLE_CHUNK_SIZE + ), + } + self.dimensions = dimensions + + def get_shape(self, dimensions): + return [self.dimensions[dim].size for dim in dimensions] + + def get_chunks(self, dimensions): + return [self.dimensions[dim].chunk_size for dim in dimensions] def validate(self): """ Checks that the schema is well-formed and within required limits. """ for field in self.fields: + for dim in field.dimensions: + if dim not in self.dimensions: + raise ValueError( + f"Dimension '{dim}' used in field '{field.name}' is " + "not defined in the schema" + ) + + chunk_nbytes = field.get_chunk_nbytes(self) # This is the Blosc max buffer size - if field.chunk_nbytes > 2147483647: - # TODO add some links to documentation here advising how to - # deal with PL values. + if chunk_nbytes > 2147483647: raise ValueError( f"Field {field.name} chunks are too large " - f"({field.chunk_nbytes} > 2**31 - 1 bytes). " + f"({chunk_nbytes} > 2**31 - 1 bytes). " "Either generate a schema and drop this field (if you don't " "need it) or reduce the variant or sample chunk sizes." ) @@ -276,8 +324,11 @@ def fromdict(d): "Zarr schema format version mismatch: " f"{d['format_version']} != {ZARR_SCHEMA_FORMAT_VERSION}" ) + ret = VcfZarrSchema(**d) ret.fields = [ZarrArraySpec(**sd) for sd in d["fields"]] + ret.dimensions = {k: VcfZarrDimension(**v) for k, v in d["dimensions"].items()} + return ret @staticmethod @@ -479,7 +530,7 @@ def init( schema.validate() partitions = VcfZarrPartition.generate_partitions( self.source.num_records, - schema.variants_chunk_size, + schema.get_chunks(["variants"])[0], target_num_partitions, max_chunks=max_variant_chunks, ) @@ -547,7 +598,7 @@ def encode_samples(self, root): shape=len(samples), dtype="str", compressor=DEFAULT_ZARR_COMPRESSOR, - chunks=(self.schema.samples_chunk_size,), + chunks=(self.schema.get_chunks(["samples"])[0],), ) array.attrs["_ARRAY_DIMENSIONS"] = ["samples"] logger.debug("Samples done") @@ -614,13 +665,13 @@ def init_array(self, root, schema, array_spec, variants_dim_size): if not zarr_utils.zarr_v3(): kwargs["dimension_separator"] = self.metadata.dimension_separator - shape = list(array_spec.shape) - # Truncate the variants dimension is max_variant_chunks was specified + shape = schema.get_shape(array_spec.dimensions) + # Truncate the variants dimension if max_variant_chunks was specified shape[0] = variants_dim_size a = root.empty( name=array_spec.name, shape=shape, - chunks=array_spec.chunks, + chunks=schema.get_chunks(array_spec.dimensions), dtype=array_spec.dtype, compressor=compressor, filters=filters, @@ -945,11 +996,13 @@ def get_max_encoding_memory(self): """ max_encoding_mem = 0 for array_spec in self.schema.fields: - max_encoding_mem = max(max_encoding_mem, array_spec.variant_chunk_nbytes) + max_encoding_mem = max( + max_encoding_mem, array_spec.variant_chunk_nbytes(self.schema) + ) gt_mem = 0 if self.has_genotypes: gt_mem = sum( - field.variant_chunk_nbytes + field.variant_chunk_nbytes(self.schema) for field in self.schema.fields if field.name.startswith("call_genotype") ) diff --git a/tests/test_cli.py b/tests/test_cli.py index c56b2422..79a18ce5 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -704,8 +704,8 @@ def test_mkschema(self, tmp_path): ) assert result.exit_code == 0 d = json.loads(result.stdout) - assert d["samples_chunk_size"] == 2 - assert d["variants_chunk_size"] == 3 + assert d["dimensions"]["samples"]["chunk_size"] == 2 + assert d["dimensions"]["variants"]["chunk_size"] == 3 @pytest.mark.parametrize("local_alleles", [False, True]) def test_mkschema_local_alleles(self, tmp_path, local_alleles): diff --git a/tests/test_icf.py b/tests/test_icf.py index 2e9275d8..479f1ab1 100644 --- a/tests/test_icf.py +++ b/tests/test_icf.py @@ -239,8 +239,8 @@ def schema(self, icf): def test_info_schemas(self, schema, name, dtype, shape, dimensions): v = schema.field_map()[name] assert v.dtype == dtype - assert tuple(v.shape) == shape assert v.dimensions == dimensions + assert tuple(schema.get_shape(dimensions)) == shape def test_info_string1(self, icf): non_missing = [v for v in icf.fields["INFO/IS1"].values if v is not None] diff --git a/tests/test_vcf_examples.py b/tests/test_vcf_examples.py index 193dfa2a..fde4c2af 100644 --- a/tests/test_vcf_examples.py +++ b/tests/test_vcf_examples.py @@ -762,7 +762,7 @@ def test_call_LAD(self, ds): [[446, -2], [393, -2], [486, -2]], ] nt.assert_array_equal(ds.call_LAD.values, call_LAD) - assert ds.call_LAD.dims == ("variants", "samples", "local_alleles") + assert ds.call_LAD.dims == ("variants", "samples", "local_alleles_AD") def test_call_LA(self, ds): # All the genotypes are 0/0 diff --git a/tests/test_vcz.py b/tests/test_vcz.py index e76196b5..63d6194e 100644 --- a/tests/test_vcz.py +++ b/tests/test_vcz.py @@ -171,23 +171,23 @@ def test_chunk_sizes(self, icf_path, samples_chunk_size, variants_chunk_size): variants_chunk_size=variants_chunk_size, samples_chunk_size=samples_chunk_size, ) - assert schema.samples_chunk_size == samples_chunk_size - assert schema.variants_chunk_size == variants_chunk_size + assert schema.dimensions["samples"].chunk_size == samples_chunk_size + assert schema.dimensions["variants"].chunk_size == variants_chunk_size found = 0 for field in schema.fields: assert field.dimensions[0] == "variants" - assert field.chunks[0] == variants_chunk_size + assert field.get_chunks(schema)[0] == variants_chunk_size if "samples" in field.dimensions: dim = field.dimensions.index("samples") - assert field.chunks[dim] == samples_chunk_size + assert field.get_chunks(schema)[dim] == samples_chunk_size found += 1 assert found > 0 def test_default_chunk_size(self, icf_path): icf = icf_mod.IntermediateColumnarFormat(icf_path) schema = icf.generate_schema() - assert schema.samples_chunk_size == 10_000 - assert schema.variants_chunk_size == 1000 + assert schema.dimensions["samples"].chunk_size == 10_000 + assert schema.dimensions["variants"].chunk_size == 1000 class TestSchemaJsonRoundTrip: @@ -298,33 +298,37 @@ class TestChunkNbytes: ) def test_example_schema(self, schema, field, value): field = schema.field_map()[field] - assert field.chunk_nbytes == value + assert field.get_chunk_nbytes(schema) == value def test_chunk_size(self, icf_path, tmp_path): store = icf_mod.IntermediateColumnarFormat(icf_path) schema = store.generate_schema(samples_chunk_size=2, variants_chunk_size=3) fields = schema.field_map() - assert fields["call_genotype"].chunk_nbytes == 3 * 2 * 2 - assert fields["variant_position"].chunk_nbytes == 3 * 4 - assert fields["variant_AC"].chunk_nbytes == 3 * 2 + assert fields["call_genotype"].get_chunk_nbytes(schema) == 3 * 2 * 2 + assert fields["variant_position"].get_chunk_nbytes(schema) == 3 * 4 + assert fields["variant_AC"].get_chunk_nbytes(schema) == 3 * 2 class TestValidateSchema: @pytest.mark.parametrize("size", [2**31, 2**31 + 1, 2**32]) def test_chunk_too_large(self, schema, size): schema = vcz.VcfZarrSchema.fromdict(schema.asdict()) + # Remove other fields as they trigger the error before + schema.fields = [schema.field_map()["variant_H2"]] field = schema.field_map()["variant_H2"] - field.shape = (size,) - field.chunks = (size,) + schema.dimensions[field.dimensions[-1]].size = size + schema.dimensions[field.dimensions[-1]].chunk_size = size with pytest.raises(ValueError, match="Field variant_H2 chunks are too large"): schema.validate() @pytest.mark.parametrize("size", [2**31 - 1, 2**30]) def test_chunk_not_too_large(self, schema, size): schema = vcz.VcfZarrSchema.fromdict(schema.asdict()) + schema.fields = [schema.field_map()["variant_H2"]] field = schema.field_map()["variant_H2"] - field.shape = (size,) - field.chunks = (size,) + schema.dimensions[field.dimensions[-1]].size = size + schema.dimensions[field.dimensions[-1]].chunk_size = size + print(schema.dimensions) schema.validate() @@ -333,15 +337,13 @@ def test_format_version(self, schema): assert schema.format_version == vcz.ZARR_SCHEMA_FORMAT_VERSION def test_chunk_size(self, schema): - assert schema.samples_chunk_size == 10000 - assert schema.variants_chunk_size == 1000 + assert schema.dimensions["samples"].chunk_size == 10000 + assert schema.dimensions["variants"].chunk_size == 1000 def test_variant_contig(self, schema): assert get_field_dict(schema, "variant_contig") == { "name": "variant_contig", "dtype": "i1", - "shape": (9,), - "chunks": (1000,), "dimensions": ("variants",), "description": "An identifier from the reference genome or an " "angle-bracketed ID string pointing to a contig in the assembly file", @@ -354,8 +356,6 @@ def test_call_genotype(self, schema): assert get_field_dict(schema, "call_genotype") == { "name": "call_genotype", "dtype": "i1", - "shape": (9, 3, 2), - "chunks": (1000, 10000, 2), "dimensions": ("variants", "samples", "ploidy"), "description": "", "source": None, @@ -373,8 +373,6 @@ def test_call_genotype_mask(self, schema): assert get_field_dict(schema, "call_genotype_mask") == { "name": "call_genotype_mask", "dtype": "bool", - "shape": (9, 3, 2), - "chunks": (1000, 10000, 2), "dimensions": ("variants", "samples", "ploidy"), "description": "", "source": None, @@ -392,8 +390,6 @@ def test_call_genotype_phased(self, schema): assert get_field_dict(schema, "call_genotype_mask") == { "name": "call_genotype_mask", "dtype": "bool", - "shape": (9, 3, 2), - "chunks": (1000, 10000, 2), "dimensions": ("variants", "samples", "ploidy"), "description": "", "source": None, @@ -411,8 +407,6 @@ def test_call_GQ(self, schema): assert get_field_dict(schema, "call_GQ") == { "name": "call_GQ", "dtype": "i1", - "shape": (9, 3), - "chunks": (1000, 10000), "dimensions": ("variants", "samples"), "description": "Genotype Quality", "source": "FORMAT/GQ", @@ -433,8 +427,6 @@ def test_call_LA(self, local_alleles_schema): "source": None, "name": "call_LA", "dtype": "i1", - "shape": (9, 3, 2), - "chunks": (1000, 10000, 2), "dimensions": ("variants", "samples", "local_alleles"), "description": ( "0-based indices into REF+ALT, indicating which alleles" @@ -819,3 +811,49 @@ def test_defaults_with_encode(self, icf_path, tmp_path): assert a.compressor.cname == "lz4" assert a.compressor.clevel == 3 assert a.compressor.shuffle == 1 + + +class TestVcfZarrDimension: + def test_dimension_initialization(self): + dim1 = vcz.VcfZarrDimension(size=100, chunk_size=20) + assert dim1.size == 100 + assert dim1.chunk_size == 20 + + # Test with only size (chunk_size should default to size) + dim2 = vcz.VcfZarrDimension(size=50) + assert dim2.size == 50 + assert dim2.chunk_size == 50 + + def test_asdict(self): + # When chunk_size equals size, it shouldn't be included in dict + dim1 = vcz.VcfZarrDimension(size=100, chunk_size=100) + assert dim1.asdict() == {"size": 100} + + # When chunk_size differs from size, it should be included in dict + dim2 = vcz.VcfZarrDimension(size=100, chunk_size=20) + assert dim2.asdict() == {"size": 100, "chunk_size": 20} + + def test_fromdict(self): + # With only size + dim1 = vcz.VcfZarrDimension.fromdict({"size": 75}) + assert dim1.size == 75 + assert dim1.chunk_size == 75 + + # With both size and chunk_size + dim2 = vcz.VcfZarrDimension.fromdict({"size": 75, "chunk_size": 25}) + assert dim2.size == 75 + assert dim2.chunk_size == 25 + + def test_json_serialization(self, icf_path): + icf = icf_mod.IntermediateColumnarFormat(icf_path) + schema = icf.generate_schema(variants_chunk_size=42, samples_chunk_size=24) + + schema_json = schema.asjson() + schema2 = vcz.VcfZarrSchema.fromjson(schema_json) + + assert schema2.dimensions["variants"].size == schema.dimensions["variants"].size + assert schema2.dimensions["variants"].chunk_size == 42 + assert schema2.dimensions["samples"].chunk_size == 24 + + assert isinstance(schema2.dimensions["variants"], vcz.VcfZarrDimension) + assert isinstance(schema2.dimensions["samples"], vcz.VcfZarrDimension)