Skip to content

Commit ce78b2f

Browse files
tomwhite and jeromekelleher
authored and committed
Use chunk_data rather than read from Zarr in query
1 parent 5729138 commit ce78b2f

File tree

2 files changed

+154
-132
lines changed

2 files changed

+154
-132
lines changed

tests/test_query.py

Lines changed: 43 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -2,6 +2,7 @@
22
import re
33
from io import StringIO
44

5+
import numpy as np
56
import pyparsing as pp
67
import pytest
78
import zarr
@@ -13,6 +14,7 @@
1314
list_samples,
1415
write_query,
1516
)
17+
from vcztools.retrieval import variant_chunk_iter
1618

1719

1820
def test_list_samples(tmp_path):
@@ -129,8 +131,44 @@ def root(self):
129131
],
130132
)
131133
def test(self, root, query_format, expected_result):
132-
generator = QueryFormatGenerator(query_format)
133-
result = "".join(generator(root))
134+
generator = QueryFormatGenerator(root, query_format)
135+
chunk_data = next(variant_chunk_iter(root))
136+
result = "".join(generator(chunk_data))
137+
assert result == expected_result
138+
139+
@pytest.mark.parametrize(
140+
("query_format", "call_mask", "expected_result"),
141+
[
142+
(
143+
r"[%DP ]\n",
144+
None,
145+
". . . \n. . . \n1 8 5 \n3 5 3 \n6 0 4 \n. 4 2 \n4 2 3 \n. . . \n. . . \n",
146+
),
147+
(
148+
r"[%DP ]\n",
149+
np.array(
150+
[
151+
[1, 1, 1,],
152+
[1, 1, 1,],
153+
[1, 0, 1,],
154+
[1, 1, 1,],
155+
[1, 1, 1,],
156+
[1, 1, 1,],
157+
[1, 1, 1,],
158+
[1, 1, 1,],
159+
[1, 1, 1,],
160+
]
161+
),
162+
". . . \n. . . \n1 5 \n3 5 3 \n6 0 4 \n. 4 2 \n4 2 3 \n. . . \n. . . \n",
163+
),
164+
],
165+
)
166+
def test_call_mask(self, root, query_format, call_mask, expected_result):
167+
generator = QueryFormatGenerator(root, query_format)
168+
chunk_data = next(variant_chunk_iter(root))
169+
if call_mask is not None:
170+
chunk_data["call_mask"] = call_mask
171+
result = "".join(generator(chunk_data))
134172
assert result == expected_result
135173

136174
@pytest.mark.parametrize(
@@ -140,8 +178,9 @@ def test(self, root, query_format, expected_result):
140178
def test_with_parse_results(self, root, query_format, expected_result):
141179
parser = QueryFormatParser()
142180
parse_results = parser(query_format)
143-
generator = QueryFormatGenerator(parse_results)
144-
result = "".join(generator(root))
181+
generator = QueryFormatGenerator(root, parse_results)
182+
chunk_data = next(variant_chunk_iter(root))
183+
result = "".join(generator(chunk_data))
145184
assert result == expected_result
146185

147186

vcztools/query.py

Lines changed: 111 additions & 128 deletions
Original file line number | Diff line number | Diff line change
@@ -59,7 +59,12 @@ def __call__(self, *args, **kwargs):
5959

6060

6161
class QueryFormatGenerator:
62-
def __init__(self, query_format, filter_expr=None):
62+
def __init__(self, root, query_format, filter_expr=None):
63+
# TODO: pass in this metadata rather than root
64+
self.sample_ids = root["sample_id"][:].tolist()
65+
self.sample_count = len(self.sample_ids)
66+
self.contig_ids = root["contig_id"][:]
67+
self.filter_ids = root["filter_id"][:]
6368
if isinstance(query_format, str):
6469
parser = QueryFormatParser()
6570
parse_results = parser(query_format)
@@ -79,46 +84,37 @@ def __call__(self, *args, **kwargs):
7984
yield from self._generator(args[0])
8085

8186
def _compose_gt_generator(self) -> Callable:
82-
def generate(root):
83-
gt_zarray = root["call_genotype"]
84-
v_chunk_size = gt_zarray.chunks[0]
85-
86-
if "call_genotype_phased" in root:
87-
phase_zarray = root["call_genotype_phased"]
88-
assert gt_zarray.chunks[:2] == phase_zarray.chunks
89-
assert gt_zarray.shape[:2] == phase_zarray.shape
90-
91-
for v_chunk_index in range(gt_zarray.cdata_shape[0]):
92-
start = v_chunk_index * v_chunk_size
93-
end = start + v_chunk_size
94-
95-
for gt_row, phase in zip(
96-
gt_zarray[start:end], phase_zarray[start:end]
97-
):
98-
99-
def stringify(gt_and_phase: tuple):
100-
gt, phase = gt_and_phase
101-
gt = [
102-
str(allele) if allele != constants.INT_MISSING else "."
103-
for allele in gt
104-
if allele != constants.INT_FILL
105-
]
106-
separator = "|" if phase else "/"
107-
return separator.join(gt)
108-
109-
gt_row = gt_row.tolist()
110-
yield map(stringify, zip(gt_row, phase))
87+
def generate(chunk_data):
88+
gt_array = chunk_data["call_genotype"]
89+
90+
if "call_genotype_phased" in chunk_data:
91+
phase_array = chunk_data["call_genotype_phased"]
92+
assert gt_array.shape[:2] == phase_array.shape
93+
94+
for gt_row, phase in zip(gt_array, phase_array):
95+
96+
def stringify(gt_and_phase: tuple):
97+
gt, phase = gt_and_phase
98+
gt = [
99+
str(allele) if allele != constants.INT_MISSING else "."
100+
for allele in gt
101+
if allele != constants.INT_FILL
102+
]
103+
separator = "|" if phase else "/"
104+
return separator.join(gt)
105+
106+
gt_row = gt_row.tolist()
107+
yield map(stringify, zip(gt_row, phase))
111108
else:
112109
# TODO: Support datasets without the phasing data
113110
raise NotImplementedError
114111

115112
return generate
116113

117114
def _compose_sample_ids_generator(self) -> Callable:
118-
def generate(root):
119-
variant_count = root["variant_position"].shape[0]
120-
sample_ids = root["sample_id"][:].tolist()
121-
yield from itertools.repeat(sample_ids, variant_count)
115+
def generate(chunk_data):
116+
variant_count = chunk_data["variant_position"].shape[0]
117+
yield from itertools.repeat(self.sample_ids, variant_count)
122118

123119
return generate
124120

@@ -134,66 +130,49 @@ def _compose_tag_generator(
134130
if tag == "SAMPLE":
135131
return self._compose_sample_ids_generator()
136132

137-
def generate(root):
138-
vcz_names = set(root.keys())
133+
def generate(chunk_data):
134+
vcz_names = set(chunk_data.keys())
139135
vcz_name = vcf_name_to_vcz_name(vcz_names, tag)
140-
zarray = root[vcz_name]
141-
contig_ids = root["contig_id"][:] if tag == "CHROM" else None
142-
filter_ids = root["filter_id"][:] if tag == "FILTER" else None
143-
v_chunk_size = zarray.chunks[0]
144-
145-
for v_chunk_index in range(zarray.cdata_shape[0]):
146-
start = v_chunk_index * v_chunk_size
147-
end = start + v_chunk_size
148-
149-
for row in zarray[start:end]:
150-
is_missing = np.any(row == -1)
151-
152-
if tag == "CHROM":
153-
assert contig_ids is not None
154-
row = contig_ids[row]
155-
if tag == "REF":
156-
row = row[0]
157-
if tag == "ALT":
158-
row = [allele for allele in row[1:] if allele] or "."
159-
if tag == "FILTER":
160-
assert filter_ids is not None
161-
162-
if np.any(row):
163-
row = filter_ids[row]
164-
else:
165-
row = "."
166-
if tag == "QUAL":
167-
if math.isnan(row):
168-
row = "."
169-
else:
170-
row = f"{row:g}"
171-
if (
172-
not subfield
173-
and not sample_loop
174-
and (isinstance(row, np.ndarray) or isinstance(row, list))
175-
):
176-
row = ",".join(map(str, row))
177-
178-
if sample_loop:
179-
sample_count = root["sample_id"].shape[0]
180-
181-
if isinstance(row, np.ndarray):
182-
row = row.tolist()
183-
row = [
184-
(
185-
str(element)
186-
if element != constants.INT_MISSING
187-
else "."
188-
)
189-
for element in row
190-
if element != constants.INT_FILL
191-
]
192-
yield row
193-
else:
194-
yield itertools.repeat(str(row), sample_count)
136+
array = chunk_data[vcz_name]
137+
for row in array:
138+
is_missing = np.any(row == -1)
139+
140+
if tag == "CHROM":
141+
row = self.contig_ids[row]
142+
if tag == "REF":
143+
row = row[0]
144+
if tag == "ALT":
145+
row = [allele for allele in row[1:] if allele] or "."
146+
if tag == "FILTER":
147+
if np.any(row):
148+
row = self.filter_ids[row]
195149
else:
196-
yield row if not is_missing else "."
150+
row = "."
151+
if tag == "QUAL":
152+
if math.isnan(row):
153+
row = "."
154+
else:
155+
row = f"{row:g}"
156+
if (
157+
not subfield
158+
and not sample_loop
159+
and (isinstance(row, np.ndarray) or isinstance(row, list))
160+
):
161+
row = ",".join(map(str, row))
162+
163+
if sample_loop:
164+
if isinstance(row, np.ndarray):
165+
row = row.tolist()
166+
row = [
167+
(str(element) if element != constants.INT_MISSING else ".")
168+
for element in row
169+
if element != constants.INT_FILL
170+
]
171+
yield row
172+
else:
173+
yield itertools.repeat(str(row), self.sample_count)
174+
else:
175+
yield row if not is_missing else "."
197176

198177
return generate
199178

@@ -203,8 +182,8 @@ def _compose_subfield_generator(self, parse_results: pp.ParseResults) -> Callabl
203182
tag, subfield_index = parse_results
204183
tag_generator = self._compose_tag_generator(tag, subfield=True)
205184

206-
def generate(root):
207-
for tag in tag_generator(root):
185+
def generate(chunk_data):
186+
for tag in tag_generator(chunk_data):
208187
if isinstance(tag, str):
209188
assert tag == "."
210189
yield "."
@@ -224,18 +203,30 @@ def _compose_sample_loop_generator(
224203
parse_results,
225204
)
226205

227-
def generate(root):
228-
iterables = (generator(root) for generator in generators)
206+
def generate(chunk_data):
207+
iterables = (generator(chunk_data) for generator in generators)
229208
zipped = zip(*iterables)
230209
zipped_zipped = (zip(*element) for element in zipped)
231-
flattened_zipped_zipped = (
232-
(
233-
subsubelement
234-
for subelement in element # sample-wise
235-
for subsubelement in subelement
210+
if "call_mask" not in chunk_data:
211+
flattened_zipped_zipped = (
212+
(
213+
subsubelement
214+
for subelement in element # sample-wise
215+
for subsubelement in subelement
216+
)
217+
for element in zipped_zipped # variant-wise
218+
)
219+
else:
220+
call_mask = chunk_data["call_mask"]
221+
flattened_zipped_zipped = (
222+
(
223+
subsubelement
224+
for j, subelement in enumerate(element) # sample-wise
225+
if call_mask[i, j]
226+
for subsubelement in subelement
227+
)
228+
for i, element in enumerate(zipped_zipped) # variant-wise
236229
)
237-
for element in zipped_zipped # variant-wise
238-
)
239230
yield from map("".join, flattened_zipped_zipped)
240231

241232
return generate
@@ -255,29 +246,21 @@ def _compose_element_generator(
255246
return self._compose_tag_generator(element, sample_loop=sample_loop)
256247
else:
257248

258-
def generate(root):
249+
def generate(chunk_data):
259250
nonlocal element
260-
variant_count = root["variant_position"].shape[0]
251+
variant_count = chunk_data["variant_position"].shape[0]
261252
if sample_loop:
262-
sample_count = root["sample_id"].shape[0]
263253
for _ in range(variant_count):
264-
yield itertools.repeat(element, sample_count)
254+
yield itertools.repeat(element, self.sample_count)
265255
else:
266256
yield from itertools.repeat(element, variant_count)
267257

268258
return generate
269259

270260
def _compose_filter_generator(self, filter_expr):
271-
def generate(root):
272-
# NOTE: this should be done at the top-level when we've
273-
# figured out what fields need to be retrieved from both
274-
# the parsed query and filter expressions.
275-
reader = retrieval.VariantChunkReader(root)
276-
for v_chunk in range(root["variant_position"].cdata_shape[0]):
277-
# print("Read v_chunk", v_chunk)
278-
chunk_data = reader[v_chunk]
279-
v_chunk_select = filter_expr.evaluate(chunk_data)
280-
yield from v_chunk_select
261+
def generate(chunk_data):
262+
v_chunk_select = filter_expr.evaluate(chunk_data)
263+
yield from v_chunk_select
281264

282265
return generate
283266

@@ -289,15 +272,12 @@ def _compose_generator(
289272
generators = (
290273
self._compose_element_generator(element) for element in parse_results
291274
)
292-
filter_generator = self._compose_filter_generator(filter_expr)
293275

294-
def generate(root) -> str:
295-
iterables = (generator(root) for generator in generators)
296-
filter_iterable = filter_generator(root)
297-
for results, filter_indicator in zip(zip(*iterables), filter_iterable):
298-
if filter_indicator:
299-
results = map(str, results)
300-
yield "".join(results)
276+
def generate(chunk_data) -> str:
277+
iterables = (generator(chunk_data) for generator in generators)
278+
for results in zip(*iterables):
279+
results = map(str, results)
280+
yield "".join(results)
301281

302282
return generate
303283

@@ -314,7 +294,10 @@ def write_query(
314294
filter_expr = filter_mod.FilterExpression(
315295
field_names=set(root), include=include, exclude=exclude
316296
)
317-
generator = QueryFormatGenerator(query_format, filter_expr)
297+
generator = QueryFormatGenerator(root, query_format, filter_expr)
318298

319-
for result in generator(root):
320-
print(result, sep="", end="", file=output)
299+
for chunk_data in retrieval.variant_chunk_iter(
300+
root, include=include, exclude=exclude
301+
):
302+
for result in generator(chunk_data):
303+
print(result, sep="", end="", file=output)

0 commit comments

Comments (0)