Skip to content

Commit ff81b5c

Browse files
committed
WIP
1 parent ac46a92 commit ff81b5c

File tree

4 files changed

+112
-134
lines changed

4 files changed

+112
-134
lines changed

bio2zarr/icf.py

Lines changed: 35 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -829,7 +829,7 @@ def __exit__(self, exc_type, exc_val, exc_tb):
829829
return False
830830

831831

832-
def convert_local_allele_field_types(fields):
832+
def convert_local_allele_field_types(fields, schema_instance):
833833
"""
834834
Update the specified list of fields to include the LAA field, and to convert
835835
any supported localisable fields to the L* counterpart.
@@ -842,45 +842,43 @@ def convert_local_allele_field_types(fields):
842842
"""
843843
fields_by_name = {field.name: field for field in fields}
844844
gt = fields_by_name["call_genotype"]
845-
if gt.shape[-1] != 2:
846-
raise ValueError("Local alleles only supported on diploid data")
847845

848-
# TODO check if LA is already in here
846+
if schema_instance.get_shape(["ploidy"])[0] != 2:
847+
raise ValueError("Local alleles only supported on diploid data")
849848

850-
shape = gt.shape[:-1]
851-
chunks = gt.chunks[:-1]
852849
dimensions = gt.dimensions[:-1]
853850

854851
la = vcz.ZarrArraySpec(
855852
name="call_LA",
856853
dtype="i1",
857-
shape=gt.shape,
858-
chunks=gt.chunks,
859854
dimensions=(*dimensions, "local_alleles"),
860855
description=(
861856
"0-based indices into REF+ALT, indicating which alleles"
862857
" are relevant (local) for the current sample"
863858
),
864859
)
860+
schema_instance.dimensions["local_alleles"] = {"size": 1}
861+
865862
ad = fields_by_name.get("call_AD", None)
866863
if ad is not None:
867864
# TODO check if call_LAD is in the list already
868865
ad.name = "call_LAD"
869866
ad.source = None
870-
ad.shape = (*shape, 2)
871-
ad.chunks = (*chunks, 2)
872-
ad.dimensions = (*dimensions, "local_alleles")
867+
ad.dimensions = (*dimensions, "local_alleles_AD")
873868
ad.description += " (local-alleles)"
869+
schema_instance.dimensions["local_alleles_AD"] = {"size": 2}
874870

875871
pl = fields_by_name.get("call_PL", None)
876872
if pl is not None:
877873
# TODO check if call_LPL is in the list already
878874
pl.name = "call_LPL"
879875
pl.source = None
880-
pl.shape = (*shape, 3)
881-
pl.chunks = (*chunks, 3)
882876
pl.description += " (local-alleles)"
883-
pl.dimensions = (*dimensions, "local_" + pl.dimensions[-1])
877+
pl.dimensions = (*dimensions, "local_" + pl.dimensions[-1].split("_")[-1])
878+
schema_instance.dimensions["local_" + pl.dimensions[-1].split("_")[-1]] = {
879+
"size": 3
880+
}
881+
884882
return [*fields, la]
885883

886884

@@ -1042,36 +1040,36 @@ def generate_schema(
10421040
if local_alleles is None:
10431041
local_alleles = False
10441042

1043+
dimensions = {
1044+
"variants": {"size": m, "chunk_size": variants_chunk_size or 1000},
1045+
"samples": {"size": n, "chunk_size": samples_chunk_size or 10000},
1046+
# ploidy added conditionally below
1047+
"alleles": {
1048+
"size": max(self.fields["ALT"].vcf_field.summary.max_number + 1, 2)
1049+
},
1050+
"filters": {"size": self.metadata.num_filters},
1051+
}
1052+
10451053
schema_instance = vcz.VcfZarrSchema(
10461054
format_version=vcz.ZARR_SCHEMA_FORMAT_VERSION,
1047-
samples_chunk_size=samples_chunk_size,
1048-
variants_chunk_size=variants_chunk_size,
1055+
dimensions=dimensions,
10491056
fields=[],
10501057
)
10511058

10521059
logger.info(
10531060
"Generating schema with chunks="
1054-
f"{schema_instance.variants_chunk_size, schema_instance.samples_chunk_size}"
1061+
f"variants={dimensions['variants']['chunk_size']}, "
1062+
f"samples={dimensions['samples']['chunk_size']}"
10551063
)
10561064

10571065
def spec_from_field(field, array_name=None):
10581066
return vcz.ZarrArraySpec.from_field(
10591067
field,
1060-
num_samples=n,
1061-
num_variants=m,
1062-
samples_chunk_size=schema_instance.samples_chunk_size,
1063-
variants_chunk_size=schema_instance.variants_chunk_size,
1068+
schema_instance,
10641069
array_name=array_name,
10651070
)
10661071

1067-
def fixed_field_spec(
1068-
name,
1069-
dtype,
1070-
source=None,
1071-
shape=(m,),
1072-
dimensions=("variants",),
1073-
chunks=None,
1074-
):
1072+
def fixed_field_spec(name, dtype, source=None, dimensions=("variants",)):
10751073
compressor = (
10761074
vcz.DEFAULT_ZARR_COMPRESSOR_BOOL.get_config()
10771075
if dtype == "bool"
@@ -1081,16 +1079,11 @@ def fixed_field_spec(
10811079
source=source,
10821080
name=name,
10831081
dtype=dtype,
1084-
shape=shape,
10851082
description="",
10861083
dimensions=dimensions,
1087-
chunks=chunks or [schema_instance.variants_chunk_size],
10881084
compressor=compressor,
10891085
)
10901086

1091-
alt_field = self.fields["ALT"]
1092-
max_alleles = alt_field.vcf_field.summary.max_number + 1
1093-
10941087
array_specs = [
10951088
fixed_field_spec(
10961089
name="variant_contig",
@@ -1099,16 +1092,12 @@ def fixed_field_spec(
10991092
fixed_field_spec(
11001093
name="variant_filter",
11011094
dtype="bool",
1102-
shape=(m, self.metadata.num_filters),
11031095
dimensions=["variants", "filters"],
1104-
chunks=(schema_instance.variants_chunk_size, self.metadata.num_filters),
11051096
),
11061097
fixed_field_spec(
11071098
name="variant_allele",
11081099
dtype="O",
1109-
shape=(m, max_alleles),
11101100
dimensions=["variants", "alleles"],
1111-
chunks=(schema_instance.variants_chunk_size, max_alleles),
11121101
),
11131102
fixed_field_spec(
11141103
name="variant_id",
@@ -1142,32 +1131,23 @@ def fixed_field_spec(
11421131

11431132
if gt_field is not None and n > 0:
11441133
ploidy = max(gt_field.summary.max_number - 1, 1)
1145-
shape = [m, n]
1146-
chunks = [
1147-
schema_instance.variants_chunk_size,
1148-
schema_instance.samples_chunk_size,
1149-
]
1150-
dimensions = ["variants", "samples"]
1134+
# Add ploidy dimension only when needed
1135+
schema_instance.dimensions["ploidy"] = {"size": ploidy}
1136+
11511137
array_specs.append(
11521138
vcz.ZarrArraySpec(
11531139
name="call_genotype_phased",
11541140
dtype="bool",
1155-
shape=list(shape),
1156-
chunks=list(chunks),
1157-
dimensions=list(dimensions),
1141+
dimensions=["variants", "samples"],
11581142
description="",
1143+
compressor=vcz.DEFAULT_ZARR_COMPRESSOR_BOOL.get_config(),
11591144
)
11601145
)
1161-
shape += [ploidy]
1162-
chunks += [ploidy]
1163-
dimensions += ["ploidy"]
11641146
array_specs.append(
11651147
vcz.ZarrArraySpec(
11661148
name="call_genotype",
11671149
dtype=gt_field.smallest_dtype(),
1168-
shape=list(shape),
1169-
chunks=list(chunks),
1170-
dimensions=list(dimensions),
1150+
dimensions=["variants", "samples", "ploidy"],
11711151
description="",
11721152
compressor=vcz.DEFAULT_ZARR_COMPRESSOR_GENOTYPES.get_config(),
11731153
)
@@ -1176,16 +1156,14 @@ def fixed_field_spec(
11761156
vcz.ZarrArraySpec(
11771157
name="call_genotype_mask",
11781158
dtype="bool",
1179-
shape=list(shape),
1180-
chunks=list(chunks),
1181-
dimensions=list(dimensions),
1159+
dimensions=["variants", "samples", "ploidy"],
11821160
description="",
11831161
compressor=vcz.DEFAULT_ZARR_COMPRESSOR_BOOL.get_config(),
11841162
)
11851163
)
11861164

11871165
if local_alleles:
1188-
array_specs = convert_local_allele_field_types(array_specs)
1166+
array_specs = convert_local_allele_field_types(array_specs, schema_instance)
11891167

11901168
schema_instance.fields = array_specs
11911169
return schema_instance

bio2zarr/plink.py

Lines changed: 11 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -69,71 +69,58 @@ def generate_schema(
6969
m = self.bed.sid_count
7070
logging.info(f"Scanned plink with {n} samples and {m} variants")
7171

72+
# Define dimensions with sizes and chunk sizes
73+
dimensions = {
74+
"variants": {"size": m, "chunk_size": variants_chunk_size or 1000},
75+
"samples": {"size": n, "chunk_size": samples_chunk_size or 10000},
76+
"ploidy": {"size": 2},
77+
"alleles": {"size": 2},
78+
}
79+
7280
schema_instance = vcz.VcfZarrSchema(
7381
format_version=vcz.ZARR_SCHEMA_FORMAT_VERSION,
74-
samples_chunk_size=samples_chunk_size,
75-
variants_chunk_size=variants_chunk_size,
82+
dimensions=dimensions,
7683
fields=[],
7784
)
7885

7986
logger.info(
8087
"Generating schema with chunks="
81-
f"{schema_instance.variants_chunk_size, schema_instance.samples_chunk_size}"
88+
f"variants={dimensions['variants']['chunk_size']}, "
89+
f"samples={dimensions['samples']['chunk_size']}"
8290
)
8391

8492
array_specs = [
8593
vcz.ZarrArraySpec(
8694
source="position",
8795
name="variant_position",
8896
dtype="i4",
89-
shape=[m],
9097
dimensions=["variants"],
91-
chunks=[schema_instance.variants_chunk_size],
9298
description=None,
9399
),
94100
vcz.ZarrArraySpec(
95101
name="variant_allele",
96102
dtype="O",
97-
shape=[m, 2],
98103
dimensions=["variants", "alleles"],
99-
chunks=[schema_instance.variants_chunk_size, 2],
100104
description=None,
101105
),
102106
vcz.ZarrArraySpec(
103107
name="call_genotype_phased",
104108
dtype="bool",
105-
shape=[m, n],
106109
dimensions=["variants", "samples"],
107-
chunks=[
108-
schema_instance.variants_chunk_size,
109-
schema_instance.samples_chunk_size,
110-
],
111110
description=None,
112111
compressor=vcz.DEFAULT_ZARR_COMPRESSOR_BOOL.get_config(),
113112
),
114113
vcz.ZarrArraySpec(
115114
name="call_genotype",
116115
dtype="i1",
117-
shape=[m, n, 2],
118116
dimensions=["variants", "samples", "ploidy"],
119-
chunks=[
120-
schema_instance.variants_chunk_size,
121-
schema_instance.samples_chunk_size,
122-
2,
123-
],
124117
description=None,
125118
compressor=vcz.DEFAULT_ZARR_COMPRESSOR_BOOL.get_config(),
126119
),
127120
vcz.ZarrArraySpec(
128121
name="call_genotype_mask",
129122
dtype="bool",
130-
shape=[m, n, 2],
131123
dimensions=["variants", "samples", "ploidy"],
132-
chunks=[
133-
schema_instance.variants_chunk_size,
134-
schema_instance.samples_chunk_size,
135-
2,
136-
],
137124
description=None,
138125
compressor=vcz.DEFAULT_ZARR_COMPRESSOR_BOOL.get_config(),
139126
),

0 commit comments

Comments
 (0)