
Commit 413a963

Fix up basic plink genotype tests
Closes #26
1 parent 95a265f

4 files changed, +122 −49 lines changed

bio2zarr/core.py

Lines changed: 0 additions & 1 deletion
@@ -88,7 +88,6 @@ def next_buffer_row(self):
         return row
 
     def flush(self):
-        # TODO just move sync_flush_array in here
        if self.buffer_row != 0:
            if len(self.array.chunks) <= 1:
                sync_flush_1d_array(
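
For orientation: flush() belongs to core.BufferedArray, which the plink changes below now drive directly (one next_buffer_row()/flush() per array, instead of going through a shared ThreadedZarrEncoder). A minimal sketch of the assumed contract, not the actual bio2zarr implementation:

import numpy as np

class BufferedArraySketch:
    # Stages up to one chunk of rows in memory, then writes them to the
    # backing zarr array at a chunk-aligned offset.
    def __init__(self, array, offset=0):
        self.array = array
        chunk_length = array.chunks[0]
        self.buff = np.zeros((chunk_length,) + array.shape[1:], dtype=array.dtype)
        self.buffer_row = 0
        self.offset = offset  # absolute row where the buffer will land

    def next_buffer_row(self):
        # Flush automatically when the buffer fills up.
        if self.buffer_row == self.buff.shape[0]:
            self.flush()
        row = self.buffer_row
        self.buffer_row += 1
        return row

    def flush(self):
        if self.buffer_row != 0:
            stop = self.offset + self.buffer_row
            self.array[self.offset : stop] = self.buff[: self.buffer_row]
            self.offset = stop
            self.buffer_row = 0

Under this contract, the start argument that encode_genotypes_slice now passes to BufferedArray is just the initial write offset, presumably required to be chunk-aligned so that concurrent workers never write to the same chunk.
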

bio2zarr/plink.py

Lines changed: 79 additions & 41 deletions
@@ -1,43 +1,51 @@
+import logging
+
 import numpy as np
 import zarr
 import bed_reader
 
 from . import core
 
 
-def encode_bed_partition_genotypes(
-    bed_path, zarr_path, start_variant, end_variant, encoder_threads=8
-):
-    bed = bed_reader.open_bed(bed_path, num_threads=1)
+logger = logging.getLogger(__name__)
+
 
+def encode_genotypes_slice(bed_path, zarr_path, start, stop):
+    bed = bed_reader.open_bed(bed_path, num_threads=1)
     store = zarr.DirectoryStore(zarr_path)
     root = zarr.group(store=store)
-    gt = core.BufferedArray(root["call_genotype"])
-    gt_mask = core.BufferedArray(root["call_genotype_mask"])
-    gt_phased = core.BufferedArray(root["call_genotype_phased"])
+    gt = core.BufferedArray(root["call_genotype"], start)
+    gt_mask = core.BufferedArray(root["call_genotype_mask"], start)
+    gt_phased = core.BufferedArray(root["call_genotype_phased"], start)
     chunk_length = gt.array.chunks[0]
-    assert start_variant % chunk_length == 0
-
-    buffered_arrays = [gt, gt_phased, gt_mask]
-
-    with core.ThreadedZarrEncoder(buffered_arrays, encoder_threads) as te:
-        start = start_variant
-        while start < end_variant:
-            stop = min(start + chunk_length, end_variant)
-            bed_chunk = bed.read(index=slice(start, stop), dtype="int8").T
-            # Note could do this without iterating over rows, but it's a bit
-            # simpler and the bottleneck is in the encoding step anyway. It's
-            # also nice to have updates on the progress monitor.
-            for values in bed_chunk:
-                j = te.next_buffer_row()
-                dest = gt.buff[j]
-                dest[values == -127] = -1
-                dest[values == 2] = 1
-                dest[values == 1, 0] = 1
-                gt_phased.buff[j] = False
-                gt_mask.buff[j] = dest == -1
-                core.update_progress(1)
-            start = stop
+    n = gt.array.shape[1]
+    assert start % chunk_length == 0
+
+    B = bed.read(dtype=np.int8).T
+
+    chunk_start = start
+    while chunk_start < stop:
+        chunk_stop = min(chunk_start + chunk_length, stop)
+        bed_chunk = bed.read(index=np.s_[:, chunk_start:chunk_stop], dtype=np.int8).T
+        # Probably should do this without iterating over rows, but it's a bit
+        # simpler and lines up better with the array buffering API. The bottleneck
+        # is in the encoding anyway.
+        for values in bed_chunk:
+            j = gt.next_buffer_row()
+            g = np.zeros_like(gt.buff[j])
+            g[values == -127] = -1
+            g[values == 2] = 1
+            g[values == 1, 0] = 1
+            gt.buff[j] = g
+            j = gt_phased.next_buffer_row()
+            gt_phased.buff[j] = False
+            j = gt_mask.next_buffer_row()
+            gt_mask.buff[j] = gt.buff[j] == -1
+        chunk_start = chunk_stop
+    gt.flush()
+    gt_phased.flush()
+    gt_mask.flush()
+    logger.debug(f"GT slice {start}:{stop} done")
 
 
 def convert(
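
The inner loop above decodes the per-call allele counts that bed_reader returns (0, 1, 2, with -127 for a missing call) into diploid genotype calls. A self-contained sketch of that mapping, matching the cases asserted by validate() further down; the function name is made up for illustration:

import numpy as np

def decode_plink_calls(values):
    # values: int8 allele counts for one variant across all samples.
    g = np.zeros((values.shape[0], 2), dtype=np.int8)
    g[values == -127] = -1  # missing call    -> [-1, -1]
    g[values == 2] = 1      # two alt alleles -> [1, 1]
    g[values == 1, 0] = 1   # one alt allele  -> [1, 0]
    return g                # zero alt alleles stay [0, 0]

decode_plink_calls(np.array([0, 1, 2, -127], dtype=np.int8))
# -> [[0, 0], [1, 0], [1, 1], [-1, -1]]
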
@@ -81,7 +89,7 @@ def convert(
         dimensions += ["ploidy"]
     a = root.empty(
         "call_genotype",
-        dtype="i8",
+        dtype="i1",
         shape=list(shape),
         chunks=list(chunks),
         compressor=core.default_compressor,
@@ -97,22 +105,52 @@ def convert(
     )
     a.attrs["_ARRAY_DIMENSIONS"] = list(dimensions)
 
-    chunks_per_future = 2  # FIXME - make a parameter
-    start = 0
-    partitions = []
-    while start < m:
-        stop = min(m, start + chunk_length * chunks_per_future)
-        partitions.append((start, stop))
-        start = stop
-    assert start == m
+    num_slices = max(1, worker_processes * 4)
+    slices = core.chunk_aligned_slices(a, num_slices)
+
+    total_chunks = sum(a.nchunks for a in root.values())
 
     progress_config = core.ProgressConfig(
-        total=m, title="Convert", units="vars", show=show_progress
+        total=total_chunks, title="Convert", units="chunks", show=show_progress
    )
    with core.ParallelWorkManager(worker_processes, progress_config) as pwm:
-        for start, end in partitions:
-            pwm.submit(encode_bed_partition_genotypes, bed_path, zarr_path, start, end)
+        for start, stop in slices:
+            pwm.submit(encode_genotypes_slice, bed_path, zarr_path, start, stop)
 
    # TODO also add atomic swap like VCF. Should be abstracted to
    # share basic code for setting up the variation dataset zarr
    zarr.consolidate_metadata(zarr_path)
+
+
+# FIXME do this more efficiently - currently reading the whole thing
+# in for convenience, and also comparing call-by-call
+def validate(bed_path, zarr_path):
+    store = zarr.DirectoryStore(zarr_path)
+    root = zarr.group(store=store)
+    call_genotype = root["call_genotype"][:]
+
+    bed = bed_reader.open_bed(bed_path, num_threads=1)
+
+    assert call_genotype.shape[0] == bed.sid_count
+    assert call_genotype.shape[1] == bed.iid_count
+    bed_genotypes = bed.read(dtype="int8").T
+    assert call_genotype.shape[0] == bed_genotypes.shape[0]
+    assert call_genotype.shape[1] == bed_genotypes.shape[1]
+    assert call_genotype.shape[2] == 2
+
+    row_id = 0
+    for bed_row, zarr_row in zip(bed_genotypes, call_genotype):
+        # print("ROW", row_id)
+        # print(bed_row, zarr_row)
+        row_id += 1
+        for bed_call, zarr_call in zip(bed_row, zarr_row):
+            if bed_call == -127:
+                assert list(zarr_call) == [-1, -1]
+            elif bed_call == 0:
+                assert list(zarr_call) == [0, 0]
+            elif bed_call == 1:
+                assert list(zarr_call) == [1, 0]
+            elif bed_call == 2:
+                assert list(zarr_call) == [1, 1]
+            else:  # pragma no cover
+                assert False
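
convert() now derives chunk-aligned slices instead of the old ad-hoc partitions, so every submitted slice starts on a chunk boundary (satisfying the assert in encode_genotypes_slice) and no two workers write to the same chunk. core.chunk_aligned_slices is defined elsewhere in bio2zarr; a rough sketch of the assumed behaviour, not the real helper:

import math

def chunk_aligned_slices_sketch(z, n):
    # Split [0, z.shape[0]) into at most n (start, stop) slices whose
    # starts fall on multiples of the chunk length.
    chunk_length = z.chunks[0]
    num_chunks = math.ceil(z.shape[0] / chunk_length)
    per_slice = max(1, math.ceil(num_chunks / n))
    slices = []
    start = 0
    while start < z.shape[0]:
        stop = min(start + per_slice * chunk_length, z.shape[0])
        slices.append((start, stop))
        start = stop
    return slices
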

tests/test_plink.py

Lines changed: 41 additions & 5 deletions
@@ -12,12 +12,13 @@ class TestSmallExample:
     @pytest.fixture(scope="class")
     def ds(self, tmp_path_factory):
         path = "tests/data/plink/plink_sim_10s_100v_10pmiss.bed"
-        out = tmp_path_factory.mktemp("data") / "example.vcf.zarr"
+        out = tmp_path_factory.mktemp("data") / "example.plink.zarr"
         plink.convert(path, out)
         return sg.load_dataset(out)
 
-
     @pytest.mark.xfail
+    # FIXME I'm not sure these are the correct genotypes here, at least
+    # the test isn't passing and others are
     def test_genotypes(self, ds):
         # Validate a few randomly selected individual calls
         # (spanning all possible states for a call)

@@ -56,7 +57,7 @@ def test_genotypes(self, ds):
         # FIXME not working
         nt.assert_array_equal(actual, expected)
 
-    @pytest.mark.xfail
+    # @pytest.mark.xfail
     @pytest.mark.parametrize(
         ["chunk_length", "chunk_width"],
         [

@@ -73,11 +74,46 @@ def test_chunk_size(
     ):
         path = "tests/data/plink/plink_sim_10s_100v_10pmiss.bed"
         out = tmp_path / "example.zarr"
-        plink.convert(path, out, chunk_length=chunk_length, chunk_width=chunk_width,
-                      worker_processes=worker_processes)
+        plink.convert(
+            path,
+            out,
+            chunk_length=chunk_length,
+            chunk_width=chunk_width,
+            worker_processes=worker_processes,
+        )
         ds2 = sg.load_dataset(out)
+        # print()
+        # print(ds.call_genotype.values[2])
+        # print(ds2.call_genotype.values[2])
+
         # print(ds2)
         # print(ds2.call_genotype.values)
         # print(ds.call_genotype.values)
         xt.assert_equal(ds, ds2)
         # TODO check array chunks
+
+
+@pytest.mark.parametrize(
+    ["chunk_length", "chunk_width"],
+    [
+        (10, 1),
+        (10, 10),
+        (33, 3),
+        (99, 10),
+        (3, 10),
+        # This one doesn't fail as it's the same as defaults
+        # (100, 10),
+    ],
+)
+@pytest.mark.parametrize("worker_processes", [0])
+def test_by_validating(tmp_path, chunk_length, chunk_width, worker_processes):
+    path = "tests/data/plink/plink_sim_10s_100v_10pmiss.bed"
+    out = tmp_path / "example.zarr"
+    plink.convert(
+        path,
+        out,
+        chunk_length=chunk_length,
+        chunk_width=chunk_width,
+        worker_processes=worker_processes,
+    )
+    plink.validate(path, out)

tests/test_simulated_data.py

Lines changed: 2 additions & 2 deletions
@@ -43,5 +43,5 @@ def test_ploidy(self, ploidy, tmp_path):
         nt.assert_equal(ds.variant_position, ts.sites_position)
 
 
-# TODO add a plink equivalant if we can find a way of programatically
-# generating plink data?
+# TODO add a plink equivalant using
+# https://fastlmm.github.io/bed-reader/#bed_reader.to_bed
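
The linked bed_reader.to_bed writes a (samples × variants) array of allele counts out as a .bed file, which would let these tests synthesize plink data on the fly. A hedged sketch, assuming the basic to_bed(path, val) call from the bed-reader docs and -127 as the int8 missing value:

import numpy as np
from bed_reader import to_bed

# 10 samples x 100 variants of random allele counts, with roughly 10% of
# calls set to missing.
rng = np.random.default_rng(42)
val = rng.choice([0, 1, 2], size=(10, 100)).astype(np.int8)
val[rng.random(val.shape) < 0.1] = -127
to_bed("simulated.bed", val)
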
