Skip to content

Commit d4d7246

Browse files
committed
Move schema generation to source classes
1 parent 9655012 commit d4d7246

File tree

6 files changed

+250
-256
lines changed

6 files changed

+250
-256
lines changed

bio2zarr/plink.py

Lines changed: 79 additions & 81 deletions
Original file line number | Diff line number | Diff line change
@@ -49,83 +49,82 @@ def iter_genotypes(self, shape, start, stop):
4949
gt[values == 2] = [0, 0] # Homozygous REF (0 in PLINK)
5050
yield gt, phased
5151

52-
53-
def generate_schema(
    bed,
    variants_chunk_size=None,
    samples_chunk_size=None,
):
    """Build the VCF Zarr schema for an opened PLINK bed file.

    Parameters
    ----------
    bed
        Opened bed file. Only ``iid_count``, ``sid_count`` and ``iid`` are
        read here (``convert`` passes the result of ``bed_reader.open_bed``).
    variants_chunk_size
        Chunk length along the variants dimension; 10_000 when ``None``.
    samples_chunk_size
        Chunk length along the samples dimension; 1000 when ``None``.

    Returns
    -------
    schema.VcfZarrSchema
        Schema holding the five array specs built below, one ``schema.Sample``
        per entry of ``bed.iid``, and empty contig and filter lists.
    """
    n = bed.iid_count  # number of samples
    m = bed.sid_count  # number of variants
    # Log through the module-level logger, like the rest of this function,
    # instead of the root logger that ``logging.info`` writes to.
    logger.info(f"Scanned plink with {n} samples and {m} variants")

    # FIXME
    if samples_chunk_size is None:
        samples_chunk_size = 1000
    if variants_chunk_size is None:
        variants_chunk_size = 10_000

    logger.info(
        f"Generating schema with chunks={variants_chunk_size, samples_chunk_size}"
    )

    array_specs = [
        # Per-variant arrays.
        schema.ZarrArraySpec.new(
            vcf_field="position",
            name="variant_position",
            dtype="i4",
            shape=[m],
            dimensions=["variants"],
            chunks=[variants_chunk_size],
            description=None,
        ),
        # The alleles axis is fixed at length two and kept in a single chunk.
        schema.ZarrArraySpec.new(
            vcf_field=None,
            name="variant_allele",
            dtype="O",
            shape=[m, 2],
            dimensions=["variants", "alleles"],
            chunks=[variants_chunk_size, 2],
            description=None,
        ),
        # Per-call arrays; the ploidy axis is fixed at length two.
        schema.ZarrArraySpec.new(
            vcf_field=None,
            name="call_genotype_phased",
            dtype="bool",
            shape=[m, n],
            dimensions=["variants", "samples"],
            chunks=[variants_chunk_size, samples_chunk_size],
            description=None,
        ),
        schema.ZarrArraySpec.new(
            vcf_field=None,
            name="call_genotype",
            dtype="i1",
            shape=[m, n, 2],
            dimensions=["variants", "samples", "ploidy"],
            chunks=[variants_chunk_size, samples_chunk_size, 2],
            description=None,
        ),
        schema.ZarrArraySpec.new(
            vcf_field=None,
            name="call_genotype_mask",
            dtype="bool",
            shape=[m, n, 2],
            dimensions=["variants", "samples", "ploidy"],
            chunks=[variants_chunk_size, samples_chunk_size, 2],
            description=None,
        ),
    ]

    return schema.VcfZarrSchema(
        format_version=schema.ZARR_SCHEMA_FORMAT_VERSION,
        samples_chunk_size=samples_chunk_size,
        variants_chunk_size=variants_chunk_size,
        fields=array_specs,
        samples=[schema.Sample(id=sample) for sample in bed.iid],
        contigs=[],
        filters=[],
    )
52+
def generate_schema(
    self,
    variants_chunk_size=None,
    samples_chunk_size=None,
):
    """Build the VCF Zarr schema describing this PLINK source.

    Reads ``iid_count``, ``sid_count`` and ``iid`` from ``self.bed``.

    Parameters
    ----------
    variants_chunk_size
        Chunk length along the variants dimension; 10_000 when ``None``.
    samples_chunk_size
        Chunk length along the samples dimension; 1000 when ``None``.

    Returns
    -------
    schema.VcfZarrSchema
        Schema holding the five array specs built below, one ``schema.Sample``
        per entry of ``self.bed.iid``, and empty contig and filter lists.
    """
    n = self.bed.iid_count  # number of samples
    m = self.bed.sid_count  # number of variants
    # Log through the module-level logger, like the rest of this method,
    # instead of the root logger that ``logging.info`` writes to.
    logger.info(f"Scanned plink with {n} samples and {m} variants")

    # FIXME
    # NOTE(review): these defaults are the reverse of the ones used by
    # generate_schema in bio2zarr/vcf2zarr/icf.py (samples 10_000,
    # variants 1000) - confirm which pair is intended before changing.
    if samples_chunk_size is None:
        samples_chunk_size = 1000
    if variants_chunk_size is None:
        variants_chunk_size = 10_000

    logger.info(
        f"Generating schema with chunks={variants_chunk_size, samples_chunk_size}"
    )

    array_specs = [
        # Per-variant arrays.
        schema.ZarrArraySpec.new(
            vcf_field="position",
            name="variant_position",
            dtype="i4",
            shape=[m],
            dimensions=["variants"],
            chunks=[variants_chunk_size],
            description=None,
        ),
        # The alleles axis is fixed at length two and kept in a single chunk.
        schema.ZarrArraySpec.new(
            vcf_field=None,
            name="variant_allele",
            dtype="O",
            shape=[m, 2],
            dimensions=["variants", "alleles"],
            chunks=[variants_chunk_size, 2],
            description=None,
        ),
        # Per-call arrays; the ploidy axis is fixed at length two.
        schema.ZarrArraySpec.new(
            vcf_field=None,
            name="call_genotype_phased",
            dtype="bool",
            shape=[m, n],
            dimensions=["variants", "samples"],
            chunks=[variants_chunk_size, samples_chunk_size],
            description=None,
        ),
        schema.ZarrArraySpec.new(
            vcf_field=None,
            name="call_genotype",
            dtype="i1",
            shape=[m, n, 2],
            dimensions=["variants", "samples", "ploidy"],
            chunks=[variants_chunk_size, samples_chunk_size, 2],
            description=None,
        ),
        schema.ZarrArraySpec.new(
            vcf_field=None,
            name="call_genotype_mask",
            dtype="bool",
            shape=[m, n, 2],
            dimensions=["variants", "samples", "ploidy"],
            chunks=[variants_chunk_size, samples_chunk_size, 2],
            description=None,
        ),
    ]

    return schema.VcfZarrSchema(
        format_version=schema.ZARR_SCHEMA_FORMAT_VERSION,
        samples_chunk_size=samples_chunk_size,
        variants_chunk_size=variants_chunk_size,
        fields=array_specs,
        samples=[schema.Sample(id=sample) for sample in self.bed.iid],
        contigs=[],
        filters=[],
    )
129128

130129

131130
def convert(
@@ -137,9 +136,8 @@ def convert(
137136
worker_processes=1,
138137
show_progress=False,
139138
):
140-
bed = bed_reader.open_bed(bed_path, num_threads=1)
141-
schema_instance = generate_schema(
142-
bed,
139+
plink_format = PlinkFormat(bed_path)
140+
schema_instance = plink_format.generate_schema(
143141
variants_chunk_size=variants_chunk_size,
144142
samples_chunk_size=samples_chunk_size,
145143
)
@@ -148,7 +146,7 @@ def convert(
148146
# Rough heuristic to split work up enough to keep utilisation high
149147
target_num_partitions = max(1, worker_processes * 4)
150148
vzw.init(
151-
PlinkFormat(bed_path),
149+
plink_format,
152150
target_num_partitions=target_num_partitions,
153151
schema=schema_instance,
154152
)

bio2zarr/vcf2zarr/__init__.py

Lines changed: 0 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -11,7 +11,6 @@
1111
encode_finalise,
1212
encode_init,
1313
encode_partition,
14-
generate_schema,
1514
inspect,
1615
mkschema,
1716
)
@@ -33,6 +32,5 @@
3332
"encode_partition",
3433
"inspect",
3534
"mkschema",
36-
"generate_schema",
3735
"verify",
3836
]

bio2zarr/vcf2zarr/icf.py

Lines changed: 155 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -969,6 +969,161 @@ def iter_genotypes(self, shape, start, stop):
969969
sanitised_phased = sanitise_value_int_1d(shape[:-1], phased)
970970
yield sanitised_genotypes, sanitised_phased
971971

972+
def generate_schema(
    self, variants_chunk_size=None, samples_chunk_size=None, local_alleles=None
):
    """Build the VCF Zarr schema for this intermediate columnar store.

    Parameters
    ----------
    variants_chunk_size
        Chunk length along the variants dimension; 1000 when ``None``.
    samples_chunk_size
        Chunk length along the samples dimension; 10_000 when ``None``.
    local_alleles
        When truthy, the array specs are passed through
        ``convert_local_allele_field_types`` before the schema is built.
        ``None`` is treated as ``False``.

    Returns
    -------
    schema.VcfZarrSchema
        Schema whose samples, contigs and filters come from ``self.metadata``.
    """
    # Import schema here to avoid circular import
    from bio2zarr import schema

    num_variants = self.num_records
    num_samples = self.num_samples
    if samples_chunk_size is None:
        samples_chunk_size = 10_000
    if variants_chunk_size is None:
        variants_chunk_size = 1000
    if local_alleles is None:
        local_alleles = False
    logger.info(
        f"Generating schema with chunks={variants_chunk_size, samples_chunk_size}"
    )

    def field_spec(field, array_name=None):
        # Spec derived from a field's own metadata.
        return schema.ZarrArraySpec.from_field(
            field,
            num_samples=num_samples,
            num_variants=num_variants,
            samples_chunk_size=samples_chunk_size,
            variants_chunk_size=variants_chunk_size,
            array_name=array_name,
        )

    def fixed_spec(
        name,
        dtype,
        vcf_field=None,
        shape=(num_variants,),
        dimensions=("variants",),
        chunks=None,
    ):
        # Spec for a fixed field; one-dimensional over variants unless the
        # caller overrides shape, dimensions and chunks.
        return schema.ZarrArraySpec.new(
            vcf_field=vcf_field,
            name=name,
            dtype=dtype,
            shape=shape,
            description="",
            dimensions=dimensions,
            chunks=chunks if chunks else [variants_chunk_size],
        )

    # The alleles axis holds the largest ALT count seen plus one more slot.
    max_alleles = self.fields["ALT"].vcf_field.summary.max_number + 1

    array_specs = [
        fixed_spec(
            name="variant_contig",
            dtype=core.min_int_dtype(0, self.metadata.num_contigs),
        ),
        fixed_spec(
            name="variant_filter",
            dtype="bool",
            shape=(num_variants, self.metadata.num_filters),
            dimensions=["variants", "filters"],
            chunks=(variants_chunk_size, self.metadata.num_filters),
        ),
        fixed_spec(
            name="variant_allele",
            dtype="O",
            shape=(num_variants, max_alleles),
            dimensions=["variants", "alleles"],
            chunks=(variants_chunk_size, max_alleles),
        ),
        fixed_spec(name="variant_id", dtype="O"),
        fixed_spec(name="variant_id_mask", dtype="bool"),
    ]

    # Only three of the fixed fields have a direct one-to-one mapping.
    fields_by_name = {field.full_name: field for field in self.metadata.fields}
    array_specs += [
        field_spec(fields_by_name["QUAL"], array_name="variant_quality"),
        field_spec(fields_by_name["POS"], array_name="variant_position"),
        field_spec(fields_by_name["rlen"], array_name="variant_length"),
    ]
    array_specs.extend(field_spec(field) for field in self.metadata.info_fields)

    # GT is held back and expanded into the call_genotype* arrays below;
    # every other FORMAT field maps straight onto a spec.
    gt_field = None
    for format_field in self.metadata.format_fields:
        if format_field.name == "GT":
            gt_field = format_field
        else:
            array_specs.append(field_spec(format_field))

    if gt_field is not None and num_samples > 0:
        ploidy = max(gt_field.summary.max_number - 1, 1)

        def genotype_spec(name, dtype, with_ploidy):
            # Spec over (variants, samples), plus a ploidy axis on request.
            # Fresh lists are built on every call so specs never share them.
            shape = [num_variants, num_samples]
            chunks = [variants_chunk_size, samples_chunk_size]
            dimensions = ["variants", "samples"]
            if with_ploidy:
                shape.append(ploidy)
                chunks.append(ploidy)
                dimensions.append("ploidy")
            return schema.ZarrArraySpec.new(
                vcf_field=None,
                name=name,
                dtype=dtype,
                shape=shape,
                chunks=chunks,
                dimensions=dimensions,
                description="",
            )

        array_specs.append(genotype_spec("call_genotype_phased", "bool", False))
        array_specs.append(
            genotype_spec("call_genotype", gt_field.smallest_dtype(), True)
        )
        array_specs.append(genotype_spec("call_genotype_mask", "bool", True))

    if local_alleles:
        from bio2zarr.vcf2zarr.vcz import convert_local_allele_field_types

        array_specs = convert_local_allele_field_types(array_specs)

    return schema.VcfZarrSchema(
        format_version=schema.ZARR_SCHEMA_FORMAT_VERSION,
        samples_chunk_size=samples_chunk_size,
        variants_chunk_size=variants_chunk_size,
        fields=array_specs,
        samples=self.metadata.samples,
        contigs=self.metadata.contigs,
        filters=self.metadata.filters,
    )
1126+
9721127

9731128
@dataclasses.dataclass
9741129
class IcfPartitionMetadata(core.JsonDataclass):

0 commit comments

Comments
 (0)