Skip to content

Commit b626dda

Browse files
Add buffered iterating over genotypes in plink
Previously, the entire partition was read into memory at once.
1 parent a076108 commit b626dda

File tree

3 files changed

+45
-10
lines changed

3 files changed

+45
-10
lines changed

bio2zarr/plink.py

Lines changed: 19 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -96,6 +96,18 @@ def __init__(self, path, num_variants, num_samples):
9696

9797
self.byte_lookup = lookup
9898

99+
def iter_decode(self, start, stop, buffer_size=None):
100+
"""
101+
Iterate over the variants in the specified window
102+
with the specified approximate buffer size in bytes (default=10MiB).
103+
"""
104+
if buffer_size is None:
105+
buffer_size = 10 * 1024 * 1024
106+
variants_per_read = max(1, int(buffer_size / self.bytes_per_variant))
107+
for off in range(start, stop, variants_per_read):
108+
genotypes = self.decode(off, min(off + variants_per_read, stop))
109+
yield from genotypes
110+
99111
def decode(self, start, stop):
100112
chunk_size = stop - start
101113

@@ -108,6 +120,7 @@ def decode(self, start, stop):
108120
f"Reading {chunk_size} variants ({bytes_to_read} bytes) "
109121
f"from {self.path}"
110122
)
123+
111124
# TODO make it possible to read sequentially from the same file handle,
112125
# seeking only when necessary.
113126
with open(self.path, "rb") as f:
@@ -181,19 +194,16 @@ def iter_id(self, start, stop):
181194
yield from self.bim.variant_id[start:stop]
182195

183196
def iter_alleles_and_genotypes(self, start, stop, shape, num_alleles):
184-
alt_field = self.bim.allele_1.values
185-
ref_field = self.bim.allele_2.values
186-
gt = self.bed_reader.decode(start, stop)
187-
phased = np.zeros(gt.shape[:2], dtype=bool)
188-
for i, (ref, alt) in enumerate(
189-
zip(ref_field[start:stop], alt_field[start:stop])
190-
):
197+
alt_iter = self.bim.allele_1.values[start:stop]
198+
ref_iter = self.bim.allele_2.values[start:stop]
199+
gt_iter = self.bed_reader.iter_decode(start, stop)
200+
for alt, ref, gt in zip(alt_iter, ref_iter, gt_iter):
191201
alleles = np.full(num_alleles, constants.STR_FILL, dtype="O")
192202
alleles[0] = ref
193203
alleles[1 : 1 + len(alt)] = alt
194-
204+
phased = np.zeros(gt.shape[0], dtype=bool)
195205
# rlen is the length of the REF in PLINK as there's no END annotations
196-
yield vcz.VariantData(len(alleles[0]), alleles, gt[i], phased[i])
206+
yield vcz.VariantData(len(alleles[0]), alleles, gt, phased)
197207

198208
def generate_schema(
199209
self,

bio2zarr/vcz.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -842,6 +842,7 @@ def encode_alleles_and_genotypes_partition(self, partition_index):
842842
partition_index, "call_genotype_phased"
843843
)
844844
shape = gt.buff.shape[1:]
845+
845846
for variant_data in self.source.iter_alleles_and_genotypes(
846847
partition.start, partition.stop, shape, alleles.array.shape[1]
847848
):

tests/test_plink.py

Lines changed: 25 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -93,7 +93,6 @@ def test_generated_bed_files(self, tmp_path, num_variants, num_samples):
9393
data = np.arange(num_variants * num_samples, dtype=int) % 4
9494
data[data == 3] = -127
9595
data = data.reshape((num_variants, num_samples))
96-
9796
bed_reader.to_bed(bed_file, data.T, num_threads=1)
9897

9998
bytes_per_variant = (num_samples + 3) // 4
@@ -108,6 +107,31 @@ def test_generated_bed_files(self, tmp_path, num_variants, num_samples):
108107
for k in range(num_samples):
109108
assert br_map[data[j, k]] == tuple(g[j, k])
110109

110+
@pytest.mark.parametrize(
111+
("num_variants", "num_samples"),
112+
[
113+
(1, 1),
114+
(30, 3),
115+
(300, 1000),
116+
],
117+
)
118+
@pytest.mark.parametrize("buffer_size", [0, 1, 3, 100, 100_000, None])
119+
def test_iter_decode(self, tmp_path, buffer_size, num_variants, num_samples):
120+
bed_file = tmp_path / "a_file.bed"
121+
# Generate a regular pattern of all possible values
122+
data = np.arange(num_variants * num_samples, dtype=int) % 4
123+
data[data == 3] = -127
124+
data = data.reshape((num_variants, num_samples))
125+
bed_reader.to_bed(bed_file, data.T, num_threads=1)
126+
127+
reader = plink.BedReader(bed_file, num_variants, num_samples)
128+
G = reader.decode(0, num_variants)
129+
assert G.shape == (num_variants, num_samples, 2)
130+
131+
reader = plink.BedReader(bed_file, num_variants, num_samples)
132+
G2 = list(reader.iter_decode(0, num_variants, buffer_size=buffer_size))
133+
nt.assert_array_equal(G, G2)
134+
111135

112136
class TestSmallExample:
113137
@pytest.fixture(scope="class")

0 commit comments

Comments
 (0)