Skip to content

Commit 41dfd9f

Browse files
tomwhitejeromekelleher
authored andcommitted
Retrieval iterator that takes regions, filter expressions, samples and returns variant chunks
1 parent ead8f3c commit 41dfd9f

File tree

1 file changed

+183
-4
lines changed

1 file changed

+183
-4
lines changed

vcztools/retrieval.py

Lines changed: 183 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,16 @@
11
import collections.abc
2+
from typing import Optional
3+
4+
import numpy as np
5+
import zarr
6+
7+
from vcztools import filter as filter_mod
8+
from vcztools.regions import (
9+
parse_regions,
10+
parse_targets,
11+
regions_to_chunk_indexes,
12+
regions_to_selection,
13+
)
214

315

416
# NOTE: this class is just a skeleton for now. The idea is that this
@@ -28,8 +40,175 @@ def __len__(self):
2840
def __getitem__(self, chunk):
2941
return {key: array.blocks[chunk] for key, array in self.arrays.items()}
3042

43+
def get_chunk_data(self, chunk, mask, samples_selection=None):
44+
num_samples = len(samples_selection) if samples_selection is not None else 0
45+
return {
46+
key: get_vchunk_array(
47+
array,
48+
chunk,
49+
mask,
50+
samples_selection
51+
if (key.startswith("call_") and num_samples > 0)
52+
else None,
53+
)
54+
for key, array in self.arrays.items()
55+
}
56+
57+
58+
def get_vchunk_array(zarray, v_chunk, mask, samples_selection=None):
59+
v_chunksize = zarray.chunks[0]
60+
start = v_chunksize * v_chunk
61+
end = v_chunksize * (v_chunk + 1)
62+
if samples_selection is None:
63+
result = zarray[start:end]
64+
else:
65+
result = zarray.oindex[start:end, samples_selection]
66+
if mask is not None:
67+
result = result[mask]
68+
return result
69+
70+
71+
def variant_chunk_index_iter(root, variant_regions=None, variant_targets=None):
72+
"""Iterate over variant chunk indexes that overlap the given regions or targets.
73+
74+
Returns tuples of variant chunk indexes and (optional) variant masks.
75+
76+
A variant mask of None indicates that all the variants in the chunk are included.
77+
"""
78+
79+
pos = root["variant_position"]
80+
81+
if variant_regions is None and variant_targets is None:
82+
num_chunks = pos.cdata_shape[0]
83+
# no regions or targets selected
84+
for v_chunk in range(num_chunks):
85+
v_mask_chunk = None
86+
yield v_chunk, v_mask_chunk
87+
88+
else:
89+
contigs_u = root["contig_id"][:].astype("U").tolist()
90+
regions = parse_regions(variant_regions, contigs_u)
91+
targets, complement = parse_targets(variant_targets, contigs_u)
92+
93+
# Use the region index to find the chunks that overlap specfied regions or
94+
# targets
95+
region_index = root["region_index"][:]
96+
chunk_indexes = regions_to_chunk_indexes(
97+
regions,
98+
targets,
99+
complement,
100+
region_index,
101+
)
102+
103+
# Then use only load required variant_contig/position chunks
104+
if len(chunk_indexes) == 0:
105+
# no chunks - no variants to write
106+
return
107+
elif len(chunk_indexes) == 1:
108+
# single chunk
109+
block_sel = chunk_indexes[0]
110+
else:
111+
# zarr.blocks doesn't support int array indexing - use that when it does
112+
block_sel = slice(chunk_indexes[0], chunk_indexes[-1] + 1)
113+
114+
region_variant_contig = root["variant_contig"].blocks[block_sel][:]
115+
region_variant_position = root["variant_position"].blocks[block_sel][:]
116+
region_variant_length = root["variant_length"].blocks[block_sel][:]
117+
118+
# Find the final variant selection
119+
variant_selection = regions_to_selection(
120+
regions,
121+
targets,
122+
complement,
123+
region_variant_contig,
124+
region_variant_position,
125+
region_variant_length,
126+
)
127+
variant_mask = np.zeros(region_variant_position.shape[0], dtype=bool)
128+
variant_mask[variant_selection] = 1
129+
# Use zarr arrays to get mask chunks aligned with the main data
130+
# for convenience.
131+
z_variant_mask = zarr.array(variant_mask, chunks=pos.chunks[0])
132+
133+
for i, v_chunk in enumerate(chunk_indexes):
134+
v_mask_chunk = z_variant_mask.blocks[i]
135+
yield v_chunk, v_mask_chunk
136+
137+
138+
def variant_chunk_index_iter_with_filtering(
139+
root,
140+
*,
141+
variant_regions=None,
142+
variant_targets=None,
143+
include: Optional[str] = None,
144+
exclude: Optional[str] = None,
145+
):
146+
"""Iterate over variant chunk indexes that overlap the given regions or targets
147+
and which match the include/exclude filter expression.
148+
149+
Returns tuples of variant chunk indexes and (optional) variant masks.
150+
151+
A variant mask of None indicates that all the variants in the chunk are included.
152+
"""
153+
154+
filter_expr = filter_mod.FilterExpression(
155+
field_names=set(root), include=include, exclude=exclude
156+
)
157+
if filter_expr.parse_result is None:
158+
filter_expr = None
159+
else:
160+
filter_fields = list(filter_expr.referenced_fields)
161+
filter_fields_reader = VariantChunkReader(root, fields=filter_fields)
162+
163+
for v_chunk, v_mask_chunk in variant_chunk_index_iter(
164+
root, variant_regions, variant_targets
165+
):
166+
if filter_expr is not None:
167+
chunk_data = filter_fields_reader[v_chunk]
168+
v_mask_chunk_filter = filter_expr.evaluate(chunk_data)
169+
if v_mask_chunk is None:
170+
v_mask_chunk = v_mask_chunk_filter
171+
else:
172+
if v_mask_chunk_filter.ndim == 2:
173+
v_mask_chunk = np.expand_dims(v_mask_chunk, axis=1)
174+
v_mask_chunk = np.logical_and(v_mask_chunk, v_mask_chunk_filter)
175+
if v_mask_chunk is None or np.any(v_mask_chunk):
176+
yield v_chunk, v_mask_chunk
177+
31178

32-
def variant_chunk_iter(root, fields=None, variant_select=None):
33-
chunk_reader = VariantChunkReader(root, fields=fields)
34-
for chunk in range(len(chunk_reader)):
35-
yield chunk_reader[chunk]
179+
def variant_chunk_iter(
180+
root,
181+
*,
182+
fields: Optional[list[str]] = None,
183+
variant_regions=None,
184+
variant_targets=None,
185+
include: Optional[str] = None,
186+
exclude: Optional[str] = None,
187+
samples_selection=None,
188+
):
189+
query_fields_reader = VariantChunkReader(root, fields=fields)
190+
for v_chunk, v_mask_chunk in variant_chunk_index_iter_with_filtering(
191+
root,
192+
variant_regions=variant_regions,
193+
variant_targets=variant_targets,
194+
include=include,
195+
exclude=exclude,
196+
):
197+
# The variants_selection is used to subset variant chunks along
198+
# the variants dimension.
199+
# The call_mask is returned to the client to indicate which samples
200+
# matched (for each variant) in the case of per-sample filtering.
201+
if v_mask_chunk is None or v_mask_chunk.ndim == 1:
202+
variants_selection = v_mask_chunk
203+
call_mask = None
204+
else:
205+
variants_selection = np.any(v_mask_chunk, axis=1)
206+
call_mask = v_mask_chunk[variants_selection]
207+
if samples_selection is not None:
208+
call_mask = call_mask[:, samples_selection]
209+
chunk_data = query_fields_reader.get_chunk_data(
210+
v_chunk, variants_selection, samples_selection=samples_selection
211+
)
212+
if call_mask is not None:
213+
chunk_data["call_mask"] = call_mask
214+
yield chunk_data

0 commit comments

Comments
 (0)