Skip to content

Commit 87f3658

Browse files
Merge pull request #56 from jeromekelleher/plink-fixups
Plink fixups
2 parents 95a265f + 4da9bca commit 87f3658

File tree

5 files changed

+130
-52
lines changed

5 files changed

+130
-52
lines changed

bio2zarr/cli.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -155,12 +155,14 @@ def vcf2zarr():
155155
@click.argument("in_path", type=click.Path())
156156
@click.argument("out_path", type=click.Path())
157157
@worker_processes
158-
@click.option("--chunk-width", type=int, default=None)
159-
@click.option("--chunk-length", type=int, default=None)
160-
def convert_plink(in_path, out_path, worker_processes, chunk_width, chunk_length):
158+
@verbose
159+
@chunk_length
160+
@chunk_width
161+
def convert_plink(in_path, out_path, verbose, worker_processes, chunk_length, chunk_width):
161162
"""
162163
In development; DO NOT USE!
163164
"""
165+
setup_logging(verbose)
164166
plink.convert(
165167
in_path,
166168
out_path,

bio2zarr/core.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -88,7 +88,6 @@ def next_buffer_row(self):
8888
return row
8989

9090
def flush(self):
91-
# TODO just move sync_flush_array in here
9291
if self.buffer_row != 0:
9392
if len(self.array.chunks) <= 1:
9493
sync_flush_1d_array(

bio2zarr/plink.py

Lines changed: 82 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -1,43 +1,53 @@
1+
import logging
2+
3+
import humanfriendly
14
import numpy as np
25
import zarr
36
import bed_reader
47

58
from . import core
69

710

8-
def encode_bed_partition_genotypes(
9-
bed_path, zarr_path, start_variant, end_variant, encoder_threads=8
10-
):
11-
bed = bed_reader.open_bed(bed_path, num_threads=1)
11+
logger = logging.getLogger(__name__)
12+
1213

14+
def encode_genotypes_slice(bed_path, zarr_path, start, stop):
15+
bed = bed_reader.open_bed(bed_path, num_threads=1)
1316
store = zarr.DirectoryStore(zarr_path)
1417
root = zarr.group(store=store)
15-
gt = core.BufferedArray(root["call_genotype"])
16-
gt_mask = core.BufferedArray(root["call_genotype_mask"])
17-
gt_phased = core.BufferedArray(root["call_genotype_phased"])
18+
gt = core.BufferedArray(root["call_genotype"], start)
19+
gt_mask = core.BufferedArray(root["call_genotype_mask"], start)
20+
gt_phased = core.BufferedArray(root["call_genotype_phased"], start)
1821
chunk_length = gt.array.chunks[0]
19-
assert start_variant % chunk_length == 0
20-
21-
buffered_arrays = [gt, gt_phased, gt_mask]
22-
23-
with core.ThreadedZarrEncoder(buffered_arrays, encoder_threads) as te:
24-
start = start_variant
25-
while start < end_variant:
26-
stop = min(start + chunk_length, end_variant)
27-
bed_chunk = bed.read(index=slice(start, stop), dtype="int8").T
28-
# Note could do this without iterating over rows, but it's a bit
29-
# simpler and the bottleneck is in the encoding step anyway. It's
30-
# also nice to have updates on the progress monitor.
31-
for values in bed_chunk:
32-
j = te.next_buffer_row()
33-
dest = gt.buff[j]
34-
dest[values == -127] = -1
35-
dest[values == 2] = 1
36-
dest[values == 1, 0] = 1
37-
gt_phased.buff[j] = False
38-
gt_mask.buff[j] = dest == -1
39-
core.update_progress(1)
40-
start = stop
22+
n = gt.array.shape[1]
23+
assert start % chunk_length == 0
24+
25+
logger.debug(f"Reading slice {start}:{stop}")
26+
chunk_start = start
27+
while chunk_start < stop:
28+
chunk_stop = min(chunk_start + chunk_length, stop)
29+
logger.debug(f"Reading bed slice {chunk_start}:{chunk_stop}")
30+
bed_chunk = bed.read(slice(chunk_start, chunk_stop), dtype=np.int8).T
31+
logger.debug(f"Got bed slice {humanfriendly.format_size(bed_chunk.nbytes)}")
32+
# Probably should do this without iterating over rows, but it's a bit
33+
# simpler and lines up better with the array buffering API. The bottleneck
34+
# is in the encoding anyway.
35+
for values in bed_chunk:
36+
j = gt.next_buffer_row()
37+
g = np.zeros_like(gt.buff[j])
38+
g[values == -127] = -1
39+
g[values == 2] = 1
40+
g[values == 1, 0] = 1
41+
gt.buff[j] = g
42+
j = gt_phased.next_buffer_row()
43+
gt_phased.buff[j] = False
44+
j = gt_mask.next_buffer_row()
45+
gt_mask.buff[j] = gt.buff[j] == -1
46+
chunk_start = chunk_stop
47+
gt.flush()
48+
gt_phased.flush()
49+
gt_mask.flush()
50+
logger.debug(f"GT slice {start}:{stop} done")
4151

4252

4353
def convert(
@@ -53,6 +63,7 @@ def convert(
5363
n = bed.iid_count
5464
m = bed.sid_count
5565
del bed
66+
logging.info(f"Scanned plink with {n} samples and {m} variants")
5667

5768
# FIXME
5869
if chunk_width is None:
@@ -81,7 +92,7 @@ def convert(
8192
dimensions += ["ploidy"]
8293
a = root.empty(
8394
"call_genotype",
84-
dtype="i8",
95+
dtype="i1",
8596
shape=list(shape),
8697
chunks=list(chunks),
8798
compressor=core.default_compressor,
@@ -97,22 +108,52 @@ def convert(
97108
)
98109
a.attrs["_ARRAY_DIMENSIONS"] = list(dimensions)
99110

100-
chunks_per_future = 2 # FIXME - make a parameter
101-
start = 0
102-
partitions = []
103-
while start < m:
104-
stop = min(m, start + chunk_length * chunks_per_future)
105-
partitions.append((start, stop))
106-
start = stop
107-
assert start == m
111+
num_slices = max(1, worker_processes * 4)
112+
slices = core.chunk_aligned_slices(a, num_slices)
113+
114+
total_chunks = sum(a.nchunks for a in root.values())
108115

109116
progress_config = core.ProgressConfig(
110-
total=m, title="Convert", units="vars", show=show_progress
117+
total=total_chunks, title="Convert", units="chunks", show=show_progress
111118
)
112119
with core.ParallelWorkManager(worker_processes, progress_config) as pwm:
113-
for start, end in partitions:
114-
pwm.submit(encode_bed_partition_genotypes, bed_path, zarr_path, start, end)
120+
for start, stop in slices:
121+
pwm.submit(encode_genotypes_slice, bed_path, zarr_path, start, stop)
115122

116123
# TODO also add atomic swap like VCF. Should be abstracted to
117124
# share basic code for setting up the variation dataset zarr
118125
zarr.consolidate_metadata(zarr_path)
126+
127+
128+
# FIXME do this more efficiently - currently reading the whole thing
129+
# in for convenience, and also comparing call-by-call
130+
def validate(bed_path, zarr_path):
131+
store = zarr.DirectoryStore(zarr_path)
132+
root = zarr.group(store=store)
133+
call_genotype = root["call_genotype"][:]
134+
135+
bed = bed_reader.open_bed(bed_path, num_threads=1)
136+
137+
assert call_genotype.shape[0] == bed.sid_count
138+
assert call_genotype.shape[1] == bed.iid_count
139+
bed_genotypes = bed.read(dtype="int8").T
140+
assert call_genotype.shape[0] == bed_genotypes.shape[0]
141+
assert call_genotype.shape[1] == bed_genotypes.shape[1]
142+
assert call_genotype.shape[2] == 2
143+
144+
row_id = 0
145+
for bed_row, zarr_row in zip(bed_genotypes, call_genotype):
146+
# print("ROW", row_id)
147+
# print(bed_row, zarr_row)
148+
row_id += 1
149+
for bed_call, zarr_call in zip(bed_row, zarr_row):
150+
if bed_call == -127:
151+
assert list(zarr_call) == [-1, -1]
152+
elif bed_call == 0:
153+
assert list(zarr_call) == [0, 0]
154+
elif bed_call == 1:
155+
assert list(zarr_call) == [1, 0]
156+
elif bed_call == 2:
157+
assert list(zarr_call) == [1, 1]
158+
else: # pragma no cover
159+
assert False

tests/test_plink.py

Lines changed: 41 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -12,12 +12,13 @@ class TestSmallExample:
1212
@pytest.fixture(scope="class")
1313
def ds(self, tmp_path_factory):
1414
path = "tests/data/plink/plink_sim_10s_100v_10pmiss.bed"
15-
out = tmp_path_factory.mktemp("data") / "example.vcf.zarr"
15+
out = tmp_path_factory.mktemp("data") / "example.plink.zarr"
1616
plink.convert(path, out)
1717
return sg.load_dataset(out)
1818

19-
2019
@pytest.mark.xfail
20+
# FIXME I'm not sure these are the correct genotypes here, at least
21+
# this test isn't passing while the others are
2122
def test_genotypes(self, ds):
2223
# Validate a few randomly selected individual calls
2324
# (spanning all possible states for a call)
@@ -56,7 +57,7 @@ def test_genotypes(self, ds):
5657
# FIXME not working
5758
nt.assert_array_equal(actual, expected)
5859

59-
@pytest.mark.xfail
60+
# @pytest.mark.xfail
6061
@pytest.mark.parametrize(
6162
["chunk_length", "chunk_width"],
6263
[
@@ -73,11 +74,46 @@ def test_chunk_size(
7374
):
7475
path = "tests/data/plink/plink_sim_10s_100v_10pmiss.bed"
7576
out = tmp_path / "example.zarr"
76-
plink.convert(path, out, chunk_length=chunk_length, chunk_width=chunk_width,
77-
worker_processes=worker_processes)
77+
plink.convert(
78+
path,
79+
out,
80+
chunk_length=chunk_length,
81+
chunk_width=chunk_width,
82+
worker_processes=worker_processes,
83+
)
7884
ds2 = sg.load_dataset(out)
85+
# print()
86+
# print(ds.call_genotype.values[2])
87+
# print(ds2.call_genotype.values[2])
88+
7989
# print(ds2)
8090
# print(ds2.call_genotype.values)
8191
# print(ds.call_genotype.values)
8292
xt.assert_equal(ds, ds2)
8393
# TODO check array chunks
94+
95+
96+
@pytest.mark.parametrize(
97+
["chunk_length", "chunk_width"],
98+
[
99+
(10, 1),
100+
(10, 10),
101+
(33, 3),
102+
(99, 10),
103+
(3, 10),
104+
# This one doesn't fail as it's the same as defaults
105+
# (100, 10),
106+
],
107+
)
108+
@pytest.mark.parametrize("worker_processes", [0])
109+
def test_by_validating(tmp_path, chunk_length, chunk_width, worker_processes):
110+
path = "tests/data/plink/plink_sim_10s_100v_10pmiss.bed"
111+
out = tmp_path / "example.zarr"
112+
plink.convert(
113+
path,
114+
out,
115+
chunk_length=chunk_length,
116+
chunk_width=chunk_width,
117+
worker_processes=worker_processes,
118+
)
119+
plink.validate(path, out)

tests/test_simulated_data.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -43,5 +43,5 @@ def test_ploidy(self, ploidy, tmp_path):
4343
nt.assert_equal(ds.variant_position, ts.sites_position)
4444

4545

46-
# TODO add a plink equivalent if we can find a way of programmatically
47-
# generating plink data?
46+
# TODO add a plink equivalent using
47+
# https://fastlmm.github.io/bed-reader/#bed_reader.to_bed

0 commit comments

Comments
 (0)