Skip to content

Commit 5d2bd3b

Browse files
Simplify logic around Dimension init
Centralise logic around default chunk sizes
1 parent a63054d commit 5d2bd3b

File tree

5 files changed

+140
-74
lines changed

5 files changed

+140
-74
lines changed

bio2zarr/plink.py

Lines changed: 8 additions & 13 deletions
Original file line number | Diff line number | Diff line change
@@ -73,19 +73,14 @@ def generate_schema(
7373
n = self.bed.iid_count
7474
m = self.bed.sid_count
7575
logging.info(f"Scanned plink with {n} samples and {m} variants")
76-
77-
# Define dimensions with sizes and chunk sizes
78-
dimensions = {
79-
"variants": vcz.VcfZarrDimension(
80-
size=m, chunk_size=variants_chunk_size or vcz.DEFAULT_VARIANT_CHUNK_SIZE
81-
),
82-
"samples": vcz.VcfZarrDimension(
83-
size=n, chunk_size=samples_chunk_size or vcz.DEFAULT_SAMPLE_CHUNK_SIZE
84-
),
85-
"ploidy": vcz.VcfZarrDimension(size=2),
86-
"alleles": vcz.VcfZarrDimension(size=2),
87-
}
88-
76+
dimensions = vcz.standard_dimensions(
77+
variants_size=m,
78+
variants_chunk_size=variants_chunk_size,
79+
samples_size=n,
80+
samples_chunk_size=samples_chunk_size,
81+
ploidy_size=2,
82+
alleles_size=2,
83+
)
8984
schema_instance = vcz.VcfZarrSchema(
9085
format_version=vcz.ZARR_SCHEMA_FORMAT_VERSION,
9186
dimensions=dimensions,

bio2zarr/tskit.py

Lines changed: 8 additions & 12 deletions
Original file line number | Diff line number | Diff line change
@@ -135,18 +135,14 @@ def generate_schema(
135135
logging.info(
136136
f"Maximum ploidy: {self.max_ploidy}, maximum alleles: {max_alleles}"
137137
)
138-
139-
dimensions = {
140-
"variants": vcz.VcfZarrDimension(
141-
size=m, chunk_size=variants_chunk_size or vcz.DEFAULT_VARIANT_CHUNK_SIZE
142-
),
143-
"samples": vcz.VcfZarrDimension(
144-
size=n, chunk_size=samples_chunk_size or vcz.DEFAULT_SAMPLE_CHUNK_SIZE
145-
),
146-
"ploidy": vcz.VcfZarrDimension(size=self.max_ploidy),
147-
"alleles": vcz.VcfZarrDimension(size=max_alleles),
148-
}
149-
138+
dimensions = vcz.standard_dimensions(
139+
variants_size=m,
140+
variants_chunk_size=variants_chunk_size,
141+
samples_size=n,
142+
samples_chunk_size=samples_chunk_size,
143+
ploidy_size=self.max_ploidy,
144+
alleles_size=max_alleles,
145+
)
150146
schema_instance = vcz.VcfZarrSchema(
151147
format_version=vcz.ZARR_SCHEMA_FORMAT_VERSION,
152148
dimensions=dimensions,

bio2zarr/vcf.py

Lines changed: 6 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -858,8 +858,8 @@ def convert_local_allele_field_types(fields, schema_instance):
858858
" are relevant (local) for the current sample"
859859
),
860860
)
861-
schema_instance.dimensions["local_alleles"] = vcz.VcfZarrDimension(
862-
size=schema_instance.dimensions["ploidy"].size
861+
schema_instance.dimensions["local_alleles"] = vcz.VcfZarrDimension.unchunked(
862+
schema_instance.dimensions["ploidy"].size
863863
)
864864

865865
ad = fields_by_name.get("call_AD", None)
@@ -869,7 +869,9 @@ def convert_local_allele_field_types(fields, schema_instance):
869869
ad.source = None
870870
ad.dimensions = (*dimensions, "local_alleles_AD")
871871
ad.description += " (local-alleles)"
872-
schema_instance.dimensions["local_alleles_AD"] = vcz.VcfZarrDimension(size=2)
872+
schema_instance.dimensions["local_alleles_AD"] = vcz.VcfZarrDimension.unchunked(
873+
2
874+
)
873875

874876
pl = fields_by_name.get("call_PL", None)
875877
if pl is not None:
@@ -879,7 +881,7 @@ def convert_local_allele_field_types(fields, schema_instance):
879881
pl.description += " (local-alleles)"
880882
pl.dimensions = (*dimensions, "local_" + pl.dimensions[-1].split("_")[-1])
881883
schema_instance.dimensions["local_" + pl.dimensions[-1].split("_")[-1]] = (
882-
vcz.VcfZarrDimension(size=3)
884+
vcz.VcfZarrDimension.unchunked(3)
883885
)
884886

885887
return [*fields, la]

bio2zarr/vcz.py

Lines changed: 12 additions & 28 deletions
Original file line number | Diff line number | Diff line change
@@ -102,28 +102,18 @@ def generate_schema(self, variants_chunk_size, samples_chunk_size, local_alleles
102102
@dataclasses.dataclass
103103
class VcfZarrDimension:
104104
size: int
105-
chunk_size: int = None
106-
107-
def __post_init__(self):
108-
if self.chunk_size is None:
109-
self.chunk_size = self.size
105+
chunk_size: int
110106

111107
def asdict(self):
112-
result = {"size": self.size}
113-
if self.chunk_size != self.size:
114-
result["chunk_size"] = self.chunk_size
115-
return result
108+
return dataclasses.asdict(self)
116109

117110
@classmethod
118111
def fromdict(cls, d):
119-
return cls(
120-
size=d["size"],
121-
chunk_size=d.get("chunk_size", d["size"]),
122-
)
112+
return cls(**d)
123113

124114
@classmethod
125115
def unchunked(cls, size):
126-
return cls(size, size)
116+
return cls(size, max(size, 1))
127117

128118

129119
def standard_dimensions(
@@ -153,7 +143,8 @@ def standard_dimensions(
153143

154144
if alleles_size is not None:
155145
dimensions["alleles"] = VcfZarrDimension.unchunked(alleles_size)
156-
dimensions["alt_alleles"] = VcfZarrDimension.unchunked(alleles_size - 1)
146+
if alleles_size > 1:
147+
dimensions["alt_alleles"] = VcfZarrDimension.unchunked(alleles_size - 1)
157148

158149
if filters_size is not None:
159150
dimensions["filters"] = VcfZarrDimension.unchunked(filters_size)
@@ -255,8 +246,8 @@ def from_field(
255246
elif max_number > 1 or vcf_field.full_name == "FORMAT/LAA":
256247
dimensions.append(f"{vcf_field.category}_{vcf_field.name}_dim")
257248
if dimensions[-1] not in schema.dimensions:
258-
schema.dimensions[dimensions[-1]] = VcfZarrDimension(
259-
size=vcf_field.max_number
249+
schema.dimensions[dimensions[-1]] = VcfZarrDimension.unchunked(
250+
vcf_field.max_number
260251
)
261252

262253
return ZarrArraySpec(
@@ -329,7 +320,7 @@ def __init__(
329320
self,
330321
format_version: str,
331322
fields: list,
332-
dimensions: dict = None,
323+
dimensions: dict,
333324
defaults: dict = None,
334325
):
335326
self.format_version = format_version
@@ -340,15 +331,6 @@ def __init__(
340331
if defaults.get("filters", None) is None:
341332
defaults["filters"] = []
342333
self.defaults = defaults
343-
if dimensions is None:
344-
dimensions = {
345-
"variants": VcfZarrDimension(
346-
size=0, chunk_size=DEFAULT_VARIANT_CHUNK_SIZE
347-
),
348-
"samples": VcfZarrDimension(
349-
size=0, chunk_size=DEFAULT_SAMPLE_CHUNK_SIZE
350-
),
351-
}
352334
self.dimensions = dimensions
353335

354336
def get_shape(self, dimensions):
@@ -394,7 +376,9 @@ def fromdict(d):
394376

395377
ret = VcfZarrSchema(**d)
396378
ret.fields = [ZarrArraySpec(**sd) for sd in d["fields"]]
397-
ret.dimensions = {k: VcfZarrDimension(**v) for k, v in d["dimensions"].items()}
379+
ret.dimensions = {
380+
k: VcfZarrDimension.fromdict(v) for k, v in d["dimensions"].items()
381+
}
398382

399383
return ret
400384

tests/test_vcz.py

Lines changed: 106 additions & 17 deletions
Original file line number | Diff line number | Diff line change
@@ -751,6 +751,7 @@ def test_custom_defaults(self, icf_path):
751751
schema = vcz.VcfZarrSchema(
752752
format_version=vcz.ZARR_SCHEMA_FORMAT_VERSION,
753753
fields=[],
754+
dimensions={},
754755
defaults=custom_defaults,
755756
)
756757

@@ -761,6 +762,7 @@ def test_partial_defaults(self, icf_path):
761762
schema1 = vcz.VcfZarrSchema(
762763
format_version=vcz.ZARR_SCHEMA_FORMAT_VERSION,
763764
fields=[],
765+
dimensions={},
764766
defaults={"compressor": {"id": "blosc", "cname": "zlib", "clevel": 5}},
765767
)
766768
assert schema1.defaults["compressor"] == {
@@ -774,6 +776,7 @@ def test_partial_defaults(self, icf_path):
774776
schema2 = vcz.VcfZarrSchema(
775777
format_version=vcz.ZARR_SCHEMA_FORMAT_VERSION,
776778
fields=[],
779+
dimensions={},
777780
defaults={"filters": [{"id": "delta"}]},
778781
)
779782
assert (
@@ -819,27 +822,21 @@ def test_dimension_initialization(self):
819822
assert dim1.size == 100
820823
assert dim1.chunk_size == 20
821824

822-
# Test with only size (chunk_size should default to size)
823-
dim2 = vcz.VcfZarrDimension(size=50)
824-
assert dim2.size == 50
825-
assert dim2.chunk_size == 50
825+
def test_unchunked(self):
826+
dim = vcz.VcfZarrDimension.unchunked(50)
827+
assert dim.size == 50
828+
assert dim.chunk_size == 50
826829

827-
def test_asdict(self):
828-
# When chunk_size equals size, it shouldn't be included in dict
829-
dim1 = vcz.VcfZarrDimension(size=100, chunk_size=100)
830-
assert dim1.asdict() == {"size": 100}
830+
def test_unchunked_zero_size(self):
831+
dim = vcz.VcfZarrDimension.unchunked(0)
832+
assert dim.size == 0
833+
assert dim.chunk_size == 1
831834

832-
# When chunk_size differs from size, it should be included in dict
833-
dim2 = vcz.VcfZarrDimension(size=100, chunk_size=20)
834-
assert dim2.asdict() == {"size": 100, "chunk_size": 20}
835+
def test_asdict(self):
836+
dim1 = vcz.VcfZarrDimension(size=100, chunk_size=101)
837+
assert dim1.asdict() == {"size": 100, "chunk_size": 101}
835838

836839
def test_fromdict(self):
837-
# With only size
838-
dim1 = vcz.VcfZarrDimension.fromdict({"size": 75})
839-
assert dim1.size == 75
840-
assert dim1.chunk_size == 75
841-
842-
# With both size and chunk_size
843840
dim2 = vcz.VcfZarrDimension.fromdict({"size": 75, "chunk_size": 25})
844841
assert dim2.size == 75
845842
assert dim2.chunk_size == 25
@@ -898,6 +895,98 @@ def test_max_number_exceeds_dimension_size(
898895
vcz.ZarrArraySpec.from_field(vcf_field, schema)
899896

900897

898+
class TestStandardDimensions:
899+
@pytest.mark.parametrize(
900+
("size", "chunk_size", "expected_chunk_size"),
901+
[
902+
(0, None, 1),
903+
(0, 100, 100),
904+
(1, 1, 1),
905+
(1, None, 1),
906+
(1, 10, 10),
907+
(1_001, None, 1_000),
908+
(10**9, None, 1_000),
909+
(999, None, 999),
910+
(1, 100_000, 100_000),
911+
],
912+
)
913+
def test_variants(self, size, chunk_size, expected_chunk_size):
914+
dims = vcz.standard_dimensions(
915+
variants_size=size, variants_chunk_size=chunk_size, samples_size=0
916+
)
917+
assert dims["variants"] == vcz.VcfZarrDimension(size, expected_chunk_size)
918+
919+
@pytest.mark.parametrize(
920+
("size", "chunk_size", "expected_chunk_size"),
921+
[
922+
(0, None, 1),
923+
(0, 100, 100),
924+
(1, 1, 1),
925+
(1, None, 1),
926+
(1, 10, 10),
927+
(10_001, None, 10_000),
928+
(10**9, None, 10_000),
929+
(9_999, None, 9_999),
930+
(1, 100_000, 100_000),
931+
],
932+
)
933+
def test_samples(self, size, chunk_size, expected_chunk_size):
934+
dims = vcz.standard_dimensions(
935+
variants_size=0, samples_size=size, samples_chunk_size=chunk_size
936+
)
937+
assert dims["samples"] == vcz.VcfZarrDimension(size, expected_chunk_size)
938+
939+
@pytest.mark.parametrize(
940+
("kwargs", "expected"),
941+
[
942+
(
943+
{"variants_size": 1, "samples_size": 1, "alleles_size": 2},
944+
{
945+
"variants": {"size": 1, "chunk_size": 1},
946+
"samples": {"size": 1, "chunk_size": 1},
947+
"alleles": {"size": 2, "chunk_size": 2},
948+
"alt_alleles": {"size": 1, "chunk_size": 1},
949+
},
950+
),
951+
(
952+
{"variants_size": 0, "samples_size": 1, "alleles_size": 1},
953+
{
954+
"variants": {"size": 0, "chunk_size": 1},
955+
"samples": {"size": 1, "chunk_size": 1},
956+
"alleles": {"size": 1, "chunk_size": 1},
957+
},
958+
),
959+
(
960+
{"variants_size": 0, "samples_size": 1, "alleles_size": 0},
961+
{
962+
"variants": {"size": 0, "chunk_size": 1},
963+
"samples": {"size": 1, "chunk_size": 1},
964+
"alleles": {"size": 0, "chunk_size": 1},
965+
},
966+
),
967+
(
968+
{"variants_size": 0, "samples_size": 1, "filters_size": 2},
969+
{
970+
"variants": {"size": 0, "chunk_size": 1},
971+
"samples": {"size": 1, "chunk_size": 1},
972+
"filters": {"size": 2, "chunk_size": 2},
973+
},
974+
),
975+
],
976+
)
977+
def test_examples(self, kwargs, expected):
978+
dims = {k: v.asdict() for k, v in vcz.standard_dimensions(**kwargs).items()}
979+
assert dims == expected
980+
981+
@pytest.mark.parametrize("field", ["ploidy", "genotypes"])
982+
@pytest.mark.parametrize("size", [0, 1, 2])
983+
def test_simple_fields(self, field, size):
984+
dims = vcz.standard_dimensions(
985+
samples_size=1, variants_size=1, **{f"{field}_size": size}
986+
)
987+
assert dims[field].asdict() == {"size": size, "chunk_size": max(1, size)}
988+
989+
901990
def test_create_index_errors(tmp_path):
902991
root = zarr.open(tmp_path)
903992
root["foobar"] = np.array([1, 2, 3])

0 commit comments

Comments
 (0)