
Commit ff1d22b

Move code around to make common writer
1 parent 1f9274b commit ff1d22b


8 files changed: +1754 -1198 lines changed


bio2zarr/core.py

Lines changed: 10 additions & 0 deletions

@@ -34,6 +34,16 @@ def display_size(n):
     return humanfriendly.format_size(n, binary=True)
 
 
+def parse_max_memory(max_memory):
+    if max_memory is None:
+        # Effectively unbounded
+        return 2**63
+    if isinstance(max_memory, str):
+        max_memory = humanfriendly.parse_size(max_memory)
+    logger.info(f"Set memory budget to {display_size(max_memory)}")
+    return max_memory
+
+
 def min_int_dtype(min_value, max_value):
     if min_value > max_value:
         raise ValueError("min_value must be <= max_value")
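
Note: the new parse_max_memory helper accepts either a plain byte count or a human-readable size string. A minimal sketch of how a caller might use it, not part of the commit; it assumes only that humanfriendly is installed, and the example values are illustrative:

    from bio2zarr import core

    budget = core.parse_max_memory("4GiB")         # humanfriendly parses "4GiB" to 4 * 1024**3 bytes
    budget = core.parse_max_memory(8_589_934_592)  # plain integers pass through unchanged
    budget = core.parse_max_memory(None)           # None means effectively unbounded (2**63)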

bio2zarr/plink.py

Lines changed: 182 additions & 101 deletions

@@ -2,17 +2,125 @@
 
 import bed_reader
 import humanfriendly
-import numcodecs
 import numpy as np
 import zarr
 
+from bio2zarr import schema, writer
 from bio2zarr.zarr_utils import ZARR_FORMAT_KWARGS
 
 from . import core
 
 logger = logging.getLogger(__name__)
 
 
+def generate_schema(bed_path, variants_chunk_size=None, samples_chunk_size=None):
+    """
+    Generate a schema for PLINK data based on the contents of the bed file.
+    """
+    bed = bed_reader.open_bed(bed_path, num_threads=1)
+    n = bed.iid_count
+    m = bed.sid_count
+
+    if samples_chunk_size is None:
+        samples_chunk_size = 1000
+    if variants_chunk_size is None:
+        variants_chunk_size = 10_000
+
+    logger.info(
+        f"Generating PLINK schema with chunks={variants_chunk_size, samples_chunk_size}"
+    )
+
+    ploidy = 2
+    shape = [m, n]
+    chunks = [variants_chunk_size, samples_chunk_size]
+    dimensions = ["variants", "samples"]
+
+    array_specs = [
+        # Sample information
+        schema.ZarrArraySpec.new(
+            vcf_field=None,
+            name="sample_id",
+            dtype="O",
+            shape=(n,),
+            chunks=(samples_chunk_size,),
+            dimensions=["samples"],
+            description="Sample identifiers",
+        ),
+        # Variant information
+        schema.ZarrArraySpec.new(
+            vcf_field=None,
+            name="variant_position",
+            dtype=np.int32,
+            shape=(m,),
+            chunks=(variants_chunk_size,),
+            dimensions=["variants"],
+            description="The reference position",
+        ),
+        schema.ZarrArraySpec.new(
+            vcf_field=None,
+            name="variant_allele",
+            dtype="O",
+            shape=(m, 2),
+            chunks=(variants_chunk_size, 2),
+            dimensions=["variants", "alleles"],
+            description="List of the reference and alternate alleles",
+        ),
+        # Genotype information
+        schema.ZarrArraySpec.new(
+            vcf_field=None,
+            name="call_genotype_phased",
+            dtype="bool",
+            shape=list(shape),
+            chunks=list(chunks),
+            dimensions=list(dimensions),
+            description="Boolean flag indicating if genotypes are phased",
+        ),
+    ]
+
+    # Add ploidy dimension for genotype arrays
+    shape_with_ploidy = shape + [ploidy]
+    chunks_with_ploidy = chunks + [ploidy]
+    dimensions_with_ploidy = dimensions + ["ploidy"]
+
+    array_specs.extend(
+        [
+            schema.ZarrArraySpec.new(
+                vcf_field=None,
+                name="call_genotype",
+                dtype="i1",
+                shape=list(shape_with_ploidy),
+                chunks=list(chunks_with_ploidy),
+                dimensions=list(dimensions_with_ploidy),
+                description="Genotype calls coded as allele indices",
+            ),
+            schema.ZarrArraySpec.new(
+                vcf_field=None,
+                name="call_genotype_mask",
+                dtype="bool",
+                shape=list(shape_with_ploidy),
+                chunks=list(chunks_with_ploidy),
+                dimensions=list(dimensions_with_ploidy),
+                description="Mask indicating missing genotype calls",
+            ),
+        ]
+    )
+
+    # Create empty lists for VCF-specific metadata
+    samples = [{"id": sample_id} for sample_id in bed.iid]
+    contigs = []  # PLINK doesn't have contig information in the same way as VCF
+    filters = []  # PLINK doesn't use filters like VCF
+
+    return schema.VcfZarrSchema(
+        format_version=schema.ZARR_SCHEMA_FORMAT_VERSION,
+        samples_chunk_size=samples_chunk_size,
+        variants_chunk_size=variants_chunk_size,
+        fields=array_specs,
+        samples=samples,
+        contigs=contigs,
+        filters=filters,
+    )
+
+
 def encode_genotypes_slice(bed_path, zarr_path, start, stop):
     # We need to count the A2 alleles here if we want to keep the
     # alleles reported as allele_1, allele_2. It's obvious here what
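
Note: a sketch, not part of the commit, of how the new generate_schema function might be exercised on its own. The "example.bed" path is a placeholder, and reading the fields and name attributes assumes the VcfZarrSchema and ZarrArraySpec objects expose their constructor arguments directly, as the calls above suggest:

    from bio2zarr import plink

    spec = plink.generate_schema("example.bed", variants_chunk_size=1000)
    print([field.name for field in spec.fields])
    # Expected fields: sample_id, variant_position, variant_allele,
    # call_genotype_phased, call_genotype, call_genotype_mask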
@@ -63,115 +171,88 @@ def convert(
     variants_chunk_size=None,
     samples_chunk_size=None,
 ):
-    bed = bed_reader.open_bed(bed_path, num_threads=1)
-    n = bed.iid_count
-    m = bed.sid_count
-    logging.info(f"Scanned plink with {n} samples and {m} variants")
+    """
+    Convert PLINK data to zarr format using the shared writer infrastructure.
+    """
+    # Generate schema from the PLINK data
+    plink_schema = generate_schema(
+        bed_path,
+        variants_chunk_size=variants_chunk_size,
+        samples_chunk_size=samples_chunk_size,
+    )
 
-    # FIXME
-    if samples_chunk_size is None:
-        samples_chunk_size = 1000
-    if variants_chunk_size is None:
-        variants_chunk_size = 10_000
+    # Create a data source adapter for PLINK
+    plink_adapter = PlinkDataAdapter(bed_path)
 
-    root = zarr.open_group(store=zarr_path, mode="w", **ZARR_FORMAT_KWARGS)
+    # Use the general writer
+    writer_instance = writer.GenericZarrWriter(zarr_path)
+    writer_instance.init_from_schema(plink_schema)
 
-    ploidy = 2
-    shape = [m, n]
-    chunks = [variants_chunk_size, samples_chunk_size]
-    dimensions = ["variants", "samples"]
-
-    # TODO we should be reusing some logic from vcfzarr here on laying
-    # out the basic dataset, and using the schema generator. Currently
-    # we're not using the best Blosc settings for genotypes here.
-    default_compressor = numcodecs.Blosc(cname="zstd", clevel=7)
-
-    a = root.array(
-        "sample_id",
-        data=bed.iid,
-        shape=bed.iid.shape,
-        dtype="str",
-        compressor=default_compressor,
-        chunks=(samples_chunk_size,),
-    )
-    a.attrs["_ARRAY_DIMENSIONS"] = ["samples"]
-    logger.debug("Encoded samples")
-
-    # TODO encode these in slices - but read them in one go to avoid
-    # fetching repeatedly from bim file
-    a = root.array(
-        "variant_position",
-        data=bed.bp_position,
-        shape=bed.bp_position.shape,
-        dtype=np.int32,
-        compressor=default_compressor,
-        chunks=(variants_chunk_size,),
-    )
-    a.attrs["_ARRAY_DIMENSIONS"] = ["variants"]
-    logger.debug("encoded variant_position")
-
-    alleles = np.stack([bed.allele_1, bed.allele_2], axis=1)
-    a = root.array(
-        "variant_allele",
-        data=alleles,
-        shape=alleles.shape,
-        dtype="str",
-        compressor=default_compressor,
-        chunks=(variants_chunk_size, alleles.shape[1]),
+    # Encode data using the writer
+    logger.info(f"Converting PLINK data to zarr at {zarr_path}")
+    writer_instance.encode_data(
+        plink_adapter, worker_processes=worker_processes, show_progress=show_progress
     )
-    a.attrs["_ARRAY_DIMENSIONS"] = ["variants", "alleles"]
-    logger.debug("encoded variant_allele")
-
-    # TODO remove this?
-    a = root.empty(
-        name="call_genotype_phased",
-        dtype="bool",
-        shape=list(shape),
-        chunks=list(chunks),
-        compressor=default_compressor,
-        **ZARR_FORMAT_KWARGS,
-    )
-    a.attrs["_ARRAY_DIMENSIONS"] = list(dimensions)
-
-    shape += [ploidy]
-    dimensions += ["ploidy"]
-    a = root.empty(
-        name="call_genotype",
-        dtype="i1",
-        shape=list(shape),
-        chunks=list(chunks),
-        compressor=default_compressor,
-        **ZARR_FORMAT_KWARGS,
-    )
-    a.attrs["_ARRAY_DIMENSIONS"] = list(dimensions)
-
-    a = root.empty(
-        name="call_genotype_mask",
-        dtype="bool",
-        shape=list(shape),
-        chunks=list(chunks),
-        compressor=default_compressor,
-        **ZARR_FORMAT_KWARGS,
-    )
-    a.attrs["_ARRAY_DIMENSIONS"] = list(dimensions)
 
-    del bed
+    # Finalize the zarr store
+    writer_instance.finalise(show_progress)
+    zarr.consolidate_metadata(zarr_path)
+    logger.info("PLINK conversion complete")
 
-    num_slices = max(1, worker_processes * 4)
-    slices = core.chunk_aligned_slices(a, num_slices)
 
-    total_chunks = sum(a.nchunks for _, a in root.arrays())
+class PlinkDataAdapter:
+    """
+    Adapter class to provide PLINK data to the generic writer.
+    """
 
-    progress_config = core.ProgressConfig(
-        total=total_chunks, title="Convert", units="chunks", show=show_progress
-    )
-    with core.ParallelWorkManager(worker_processes, progress_config) as pwm:
-        for start, stop in slices:
-            pwm.submit(encode_genotypes_slice, bed_path, zarr_path, start, stop)
+    def __init__(self, bed_path):
+        self.bed_path = bed_path
+        self.bed = bed_reader.open_bed(bed_path, num_threads=1)
+        self.n_samples = self.bed.iid_count
+        self.n_variants = self.bed.sid_count
 
-    # TODO also add atomic swap like VCF. Should be abstracted to
-    # share basic code for setting up the variation dataset zarr
-    zarr.consolidate_metadata(zarr_path)
+    def get_sample_ids(self):
+        return self.bed.iid
+
+    def get_variant_positions(self):
+        return self.bed.bp_position
+
+    def get_variant_alleles(self):
+        return np.stack([self.bed.allele_1, self.bed.allele_2], axis=1)
+
+    def get_genotypes_slice(self, start, stop):
+        """
+        Read a slice of genotypes from the PLINK data.
+        Returns a dictionary with three arrays:
+        - genotypes: The actual genotype values
+        - phased: Whether genotypes are phased (always False for PLINK)
+        - mask: Which genotype values are missing
+        """
+        bed_chunk = self.bed.read(slice(start, stop), dtype=np.int8).T
+        n_variants = stop - start
+
+        # Create return arrays
+        gt = np.zeros((n_variants, self.n_samples, 2), dtype=np.int8)
+        gt_phased = np.zeros((n_variants, self.n_samples), dtype=bool)
+        gt_mask = np.zeros((n_variants, self.n_samples, 2), dtype=bool)
+
+        # Convert PLINK encoding to genotype encoding
+        # PLINK: 0=hom ref, 1=het, 2=hom alt, -127=missing
+        # Zarr: [0,0]=hom ref, [1,0]=het, [1,1]=hom alt, [-1,-1]=missing
+        for i, values in enumerate(bed_chunk):
+            gt[i, values == -127] = -1
+            gt[i, values == 2, :] = 1
+            gt[i, values == 1, 0] = 1
+            gt_mask[i] = gt[i] == -1
+
+        return {
+            "call_genotype": gt,
+            "call_genotype_phased": gt_phased,
+            "call_genotype_mask": gt_mask,
+        }
+
+    def close(self):
+        del self.bed
 
 
 # FIXME do this more efficiently - currently reading the whole thing
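
Note: the genotype translation in get_genotypes_slice is easiest to see on a single bed row. The following self-contained sketch, not part of the commit, applies the same mapping described in the comments above (0=hom ref, 1=het, 2=hom alt, -127=missing) to made-up values using plain numpy:

    import numpy as np

    values = np.array([0, 1, 2, -127], dtype=np.int8)  # one variant, four samples
    gt = np.zeros((values.shape[0], 2), dtype=np.int8)
    gt[values == -127] = -1   # missing        -> [-1, -1]
    gt[values == 2, :] = 1    # homozygous alt -> [1, 1]
    gt[values == 1, 0] = 1    # heterozygous   -> [1, 0]
    mask = gt == -1
    print(gt.tolist())    # [[0, 0], [1, 0], [1, 1], [-1, -1]]
    print(mask.tolist())  # [[False, False], [False, False], [False, False], [True, True]]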
