@@ -59,7 +59,12 @@ def __call__(self, *args, **kwargs):
 
 
 class QueryFormatGenerator:
-    def __init__(self, query_format, filter_expr=None):
+    def __init__(self, root, query_format, filter_expr=None):
+        # TODO: pass in this metadata rather than root
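+        # Dataset-level metadata is read once here, so the per-chunk
+        # generators below only need the chunk_data mapping passed to them.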
+        self.sample_ids = root["sample_id"][:].tolist()
+        self.sample_count = len(self.sample_ids)
+        self.contig_ids = root["contig_id"][:]
+        self.filter_ids = root["filter_id"][:]
         if isinstance(query_format, str):
             parser = QueryFormatParser()
             parse_results = parser(query_format)
@@ -79,46 +84,37 @@ def __call__(self, *args, **kwargs):
         yield from self._generator(args[0])
 
     def _compose_gt_generator(self) -> Callable:
-        def generate(root):
-            gt_zarray = root["call_genotype"]
-            v_chunk_size = gt_zarray.chunks[0]
-
-            if "call_genotype_phased" in root:
-                phase_zarray = root["call_genotype_phased"]
-                assert gt_zarray.chunks[:2] == phase_zarray.chunks
-                assert gt_zarray.shape[:2] == phase_zarray.shape
-
-                for v_chunk_index in range(gt_zarray.cdata_shape[0]):
-                    start = v_chunk_index * v_chunk_size
-                    end = start + v_chunk_size
-
-                    for gt_row, phase in zip(
-                        gt_zarray[start:end], phase_zarray[start:end]
-                    ):
-
-                        def stringify(gt_and_phase: tuple):
-                            gt, phase = gt_and_phase
-                            gt = [
-                                str(allele) if allele != constants.INT_MISSING else "."
-                                for allele in gt
-                                if allele != constants.INT_FILL
-                            ]
-                            separator = "|" if phase else "/"
-                            return separator.join(gt)
-
-                        gt_row = gt_row.tolist()
-                        yield map(stringify, zip(gt_row, phase))
+        def generate(chunk_data):
+            gt_array = chunk_data["call_genotype"]
+
+            if "call_genotype_phased" in chunk_data:
+                phase_array = chunk_data["call_genotype_phased"]
+                assert gt_array.shape[:2] == phase_array.shape
+
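+                # Each genotype row is rendered as a VCF GT string such as
+                # "0|1": missing alleles become ".", INT_FILL padding is
+                # dropped, and the separator follows the per-call phasing flag.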
+                for gt_row, phase in zip(gt_array, phase_array):
+
+                    def stringify(gt_and_phase: tuple):
+                        gt, phase = gt_and_phase
+                        gt = [
+                            str(allele) if allele != constants.INT_MISSING else "."
+                            for allele in gt
+                            if allele != constants.INT_FILL
+                        ]
+                        separator = "|" if phase else "/"
+                        return separator.join(gt)
+
+                    gt_row = gt_row.tolist()
+                    yield map(stringify, zip(gt_row, phase))
             else:
                 # TODO: Support datasets without the phasing data
                 raise NotImplementedError
 
         return generate
 
     def _compose_sample_ids_generator(self) -> Callable:
-        def generate(root):
-            variant_count = root["variant_position"].shape[0]
-            sample_ids = root["sample_id"][:].tolist()
-            yield from itertools.repeat(sample_ids, variant_count)
+        def generate(chunk_data):
+            variant_count = chunk_data["variant_position"].shape[0]
+            yield from itertools.repeat(self.sample_ids, variant_count)
 
         return generate
 
@@ -134,66 +130,49 @@ def _compose_tag_generator(
         if tag == "SAMPLE":
             return self._compose_sample_ids_generator()
 
-        def generate(root):
-            vcz_names = set(root.keys())
+        def generate(chunk_data):
+            vcz_names = set(chunk_data.keys())
             vcz_name = vcf_name_to_vcz_name(vcz_names, tag)
-            zarray = root[vcz_name]
-            contig_ids = root["contig_id"][:] if tag == "CHROM" else None
-            filter_ids = root["filter_id"][:] if tag == "FILTER" else None
-            v_chunk_size = zarray.chunks[0]
-
-            for v_chunk_index in range(zarray.cdata_shape[0]):
-                start = v_chunk_index * v_chunk_size
-                end = start + v_chunk_size
-
-                for row in zarray[start:end]:
-                    is_missing = np.any(row == -1)
-
-                    if tag == "CHROM":
-                        assert contig_ids is not None
-                        row = contig_ids[row]
-                    if tag == "REF":
-                        row = row[0]
-                    if tag == "ALT":
-                        row = [allele for allele in row[1:] if allele] or "."
-                    if tag == "FILTER":
-                        assert filter_ids is not None
-
-                        if np.any(row):
-                            row = filter_ids[row]
-                        else:
-                            row = "."
-                    if tag == "QUAL":
-                        if math.isnan(row):
-                            row = "."
-                        else:
-                            row = f"{row:g}"
-                    if (
-                        not subfield
-                        and not sample_loop
-                        and (isinstance(row, np.ndarray) or isinstance(row, list))
-                    ):
-                        row = ",".join(map(str, row))
-
-                    if sample_loop:
-                        sample_count = root["sample_id"].shape[0]
-
-                        if isinstance(row, np.ndarray):
-                            row = row.tolist()
-                            row = [
-                                (
-                                    str(element)
-                                    if element != constants.INT_MISSING
-                                    else "."
-                                )
-                                for element in row
-                                if element != constants.INT_FILL
-                            ]
-                            yield row
-                        else:
-                            yield itertools.repeat(str(row), sample_count)
+            array = chunk_data[vcz_name]
+            for row in array:
+                is_missing = np.any(row == -1)
+
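+                # Convert the raw array values for this tag into VCF text:
+                # CHROM and FILTER indexes are mapped back to contig/filter
+                # IDs, REF/ALT come from the alleles row, and a NaN QUAL
+                # becomes ".".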
+                if tag == "CHROM":
+                    row = self.contig_ids[row]
+                if tag == "REF":
+                    row = row[0]
+                if tag == "ALT":
+                    row = [allele for allele in row[1:] if allele] or "."
+                if tag == "FILTER":
+                    if np.any(row):
+                        row = self.filter_ids[row]
                     else:
-                        yield row if not is_missing else "."
+                        row = "."
+                if tag == "QUAL":
+                    if math.isnan(row):
+                        row = "."
+                    else:
+                        row = f"{row:g}"
+                if (
+                    not subfield
+                    and not sample_loop
+                    and (isinstance(row, np.ndarray) or isinstance(row, list))
+                ):
+                    row = ",".join(map(str, row))
+
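+                # In a per-sample loop, expand the row into one string per
+                # sample: array values element-wise (with "." for missing),
+                # scalar values repeated once per sample.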
+                if sample_loop:
+                    if isinstance(row, np.ndarray):
+                        row = row.tolist()
+                        row = [
+                            (str(element) if element != constants.INT_MISSING else ".")
+                            for element in row
+                            if element != constants.INT_FILL
+                        ]
+                        yield row
+                    else:
+                        yield itertools.repeat(str(row), self.sample_count)
+                else:
+                    yield row if not is_missing else "."
 
         return generate
 
@@ -203,8 +182,8 @@ def _compose_subfield_generator(self, parse_results: pp.ParseResults) -> Callabl
         tag, subfield_index = parse_results
         tag_generator = self._compose_tag_generator(tag, subfield=True)
 
-        def generate(root):
-            for tag in tag_generator(root):
+        def generate(chunk_data):
+            for tag in tag_generator(chunk_data):
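+                # A plain "." from the tag generator means the whole field is
+                # missing, so the subfield is reported as "." as well.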
                 if isinstance(tag, str):
                     assert tag == "."
                     yield "."
@@ -224,18 +203,30 @@ def _compose_sample_loop_generator(
             parse_results,
         )
 
-        def generate(root):
-            iterables = (generator(root) for generator in generators)
+        def generate(chunk_data):
+            iterables = (generator(chunk_data) for generator in generators)
             zipped = zip(*iterables)
             zipped_zipped = (zip(*element) for element in zipped)
-            flattened_zipped_zipped = (
-                (
-                    subsubelement
-                    for subelement in element  # sample-wise
-                    for subsubelement in subelement
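+            # When chunk_data carries a per-call "call_mask", only the samples
+            # whose mask entry is set are included in each variant's output.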
+            if "call_mask" not in chunk_data:
+                flattened_zipped_zipped = (
+                    (
+                        subsubelement
+                        for subelement in element  # sample-wise
+                        for subsubelement in subelement
+                    )
+                    for element in zipped_zipped  # variant-wise
+                )
+            else:
+                call_mask = chunk_data["call_mask"]
+                flattened_zipped_zipped = (
+                    (
+                        subsubelement
+                        for j, subelement in enumerate(element)  # sample-wise
+                        if call_mask[i, j]
+                        for subsubelement in subelement
+                    )
+                    for i, element in enumerate(zipped_zipped)  # variant-wise
                 )
-                for element in zipped_zipped  # variant-wise
-            )
             yield from map("".join, flattened_zipped_zipped)
 
         return generate
@@ -255,29 +246,21 @@ def _compose_element_generator(
             return self._compose_tag_generator(element, sample_loop=sample_loop)
         else:
 
-            def generate(root):
+            def generate(chunk_data):
                 nonlocal element
-                variant_count = root["variant_position"].shape[0]
+                variant_count = chunk_data["variant_position"].shape[0]
                 if sample_loop:
-                    sample_count = root["sample_id"].shape[0]
                     for _ in range(variant_count):
-                        yield itertools.repeat(element, sample_count)
+                        yield itertools.repeat(element, self.sample_count)
                 else:
                     yield from itertools.repeat(element, variant_count)
 
             return generate
 
     def _compose_filter_generator(self, filter_expr):
-        def generate(root):
-            # NOTE: this should be done at the top-level when we've
-            # figured out what fields need to be retrieved from both
-            # the parsed query and filter expressions.
-            reader = retrieval.VariantChunkReader(root)
-            for v_chunk in range(root["variant_position"].cdata_shape[0]):
-                # print("Read v_chunk", v_chunk)
-                chunk_data = reader[v_chunk]
-                v_chunk_select = filter_expr.evaluate(chunk_data)
-                yield from v_chunk_select
+        def generate(chunk_data):
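+            # The chunk has already been read by the caller; just evaluate the
+            # filter expression on it to get the per-variant selection.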
+            v_chunk_select = filter_expr.evaluate(chunk_data)
+            yield from v_chunk_select
 
         return generate
 
@@ -289,15 +272,12 @@ def _compose_generator(
         generators = (
             self._compose_element_generator(element) for element in parse_results
         )
-        filter_generator = self._compose_filter_generator(filter_expr)
 
-        def generate(root) -> str:
-            iterables = (generator(root) for generator in generators)
-            filter_iterable = filter_generator(root)
-            for results, filter_indicator in zip(zip(*iterables), filter_iterable):
-                if filter_indicator:
-                    results = map(str, results)
-                    yield "".join(results)
+        def generate(chunk_data) -> str:
+            iterables = (generator(chunk_data) for generator in generators)
+            for results in zip(*iterables):
+                results = map(str, results)
+                yield "".join(results)
 
         return generate
 
@@ -314,7 +294,10 @@ def write_query(
     filter_expr = filter_mod.FilterExpression(
         field_names=set(root), include=include, exclude=exclude
    )
-    generator = QueryFormatGenerator(query_format, filter_expr)
+    generator = QueryFormatGenerator(root, query_format, filter_expr)
 
-    for result in generator(root):
-        print(result, sep="", end="", file=output)
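+    # Format the output one variant chunk at a time, as
+    # retrieval.variant_chunk_iter produces the data for each chunk.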
+    for chunk_data in retrieval.variant_chunk_iter(
+        root, include=include, exclude=exclude
+    ):
+        for result in generator(chunk_data):
+            print(result, sep="", end="", file=output)