|
1 | 1 | import collections.abc
|
| 2 | +from typing import Optional |
| 3 | + |
| 4 | +import numpy as np |
| 5 | +import zarr |
| 6 | + |
| 7 | +from vcztools import filter as filter_mod |
| 8 | +from vcztools.regions import ( |
| 9 | + parse_regions, |
| 10 | + parse_targets, |
| 11 | + regions_to_chunk_indexes, |
| 12 | + regions_to_selection, |
| 13 | +) |
2 | 14 |
|
3 | 15 |
|
4 | 16 | # NOTE: this class is just a skeleton for now. The idea is that this
|
@@ -28,8 +40,175 @@ def __len__(self):
|
def __getitem__(self, chunk):
    """Return the data of block ``chunk`` for every array held by the reader.

    This is a method of ``VariantChunkReader``; re-indent it under the class
    when reassembling the file.

    Args:
        chunk: Block selection passed unchanged to each array's ``blocks``
            indexer.

    Returns:
        A dict mapping each key of ``self.arrays`` to ``array.blocks[chunk]``.
    """
    return {key: array.blocks[chunk] for key, array in self.arrays.items()}
|
30 | 42 |
|
def get_chunk_data(self, chunk, mask, samples_selection=None):
    """Read variant chunk ``chunk`` of every array, subset by mask and samples.

    This is a method of ``VariantChunkReader``; re-indent it under the class
    when reassembling the file.

    Args:
        chunk: Index of the variant chunk to read.
        mask: Selection applied to the first axis of every array, or None to
            keep every variant in the chunk.
        samples_selection: Optional selection of samples. It is only applied
            to arrays whose key starts with ``"call_"``.

    Returns:
        A dict mapping each key of ``self.arrays`` to the selected data.
    """
    # None and an empty selection are treated the same way: no sample
    # subsetting is performed and the arrays are read in full.
    subset_samples = samples_selection is not None and len(samples_selection) > 0
    return {
        key: get_vchunk_array(
            array,
            chunk,
            mask,
            samples_selection
            if (subset_samples and key.startswith("call_"))
            else None,
        )
        for key, array in self.arrays.items()
    }
| 56 | + |
| 57 | + |
def get_vchunk_array(zarray, v_chunk, mask, samples_selection=None):
    """Read one variant chunk of ``zarray``, optionally subset by samples and mask.

    The chunk covers rows ``[size * v_chunk, size * (v_chunk + 1))`` along the
    first dimension, where ``size`` is ``zarray.chunks[0]``.

    Args:
        zarray: Array exposing ``chunks``, slicing and (when a samples
            selection is given) an ``oindex`` orthogonal indexer.
        v_chunk: Index of the chunk along the first dimension.
        mask: Selection applied to the first axis of the loaded chunk, or
            None to keep every row.
        samples_selection: Optional selection applied to the second dimension
            through ``zarray.oindex``.

    Returns:
        The loaded (and possibly subset) chunk data.
    """
    v_chunksize = zarray.chunks[0]
    start = v_chunksize * v_chunk
    end = start + v_chunksize
    if samples_selection is None:
        result = zarray[start:end]
    else:
        # Orthogonal indexing: a row slice combined with a samples selection.
        result = zarray.oindex[start:end, samples_selection]
    return result if mask is None else result[mask]
| 69 | + |
| 70 | + |
def variant_chunk_index_iter(root, variant_regions=None, variant_targets=None):
    """Iterate over variant chunk indexes that overlap the given regions or targets.

    Returns tuples of variant chunk indexes and (optional) variant masks.

    A variant mask of None indicates that all the variants in the chunk are included.
    """
    pos = root["variant_position"]

    if variant_regions is None and variant_targets is None:
        # No regions or targets selected: every chunk is included unmasked.
        for v_chunk in range(pos.cdata_shape[0]):
            yield v_chunk, None
        return

    contigs_u = root["contig_id"][:].astype("U").tolist()
    regions = parse_regions(variant_regions, contigs_u)
    targets, complement = parse_targets(variant_targets, contigs_u)

    # Use the region index to find the chunks that overlap the specified
    # regions or targets.
    region_index = root["region_index"][:]
    chunk_indexes = regions_to_chunk_indexes(
        regions,
        targets,
        complement,
        region_index,
    )

    # Then load only the required variant_contig/position/length chunks.
    # len() is used rather than truthiness because chunk_indexes may be an
    # array.
    if len(chunk_indexes) == 0:
        # no chunks - no variants to write
        return
    # As in the slice below, chunk_indexes is assumed to be sorted ascending.
    first_chunk = chunk_indexes[0]
    if len(chunk_indexes) == 1:
        # single chunk
        block_sel = first_chunk
    else:
        # zarr.blocks doesn't support int array indexing - use that when it does
        block_sel = slice(first_chunk, chunk_indexes[-1] + 1)

    region_variant_contig = root["variant_contig"].blocks[block_sel][:]
    region_variant_position = root["variant_position"].blocks[block_sel][:]
    region_variant_length = root["variant_length"].blocks[block_sel][:]

    # Find the final variant selection, relative to the loaded span of chunks.
    variant_selection = regions_to_selection(
        regions,
        targets,
        complement,
        region_variant_contig,
        region_variant_position,
        region_variant_length,
    )
    variant_mask = np.zeros(region_variant_position.shape[0], dtype=bool)
    variant_mask[variant_selection] = True

    # The mask covers every chunk from first_chunk to the last selected chunk,
    # including any unselected chunks in between. Each chunk's slice must
    # therefore be located by its offset from first_chunk, not by its position
    # in chunk_indexes, otherwise gaps in chunk_indexes misalign the masks.
    v_chunksize = pos.chunks[0]
    for v_chunk in chunk_indexes:
        offset = (v_chunk - first_chunk) * v_chunksize
        yield v_chunk, variant_mask[offset : offset + v_chunksize]
| 136 | + |
| 137 | + |
def variant_chunk_index_iter_with_filtering(
    root,
    *,
    variant_regions=None,
    variant_targets=None,
    include: Optional[str] = None,
    exclude: Optional[str] = None,
):
    """Iterate over variant chunk indexes that overlap the given regions or targets
    and which match the include/exclude filter expression.

    Returns tuples of variant chunk indexes and (optional) variant masks.

    A variant mask of None indicates that all the variants in the chunk are included.
    Chunks whose combined mask selects nothing are skipped.
    """
    filter_expr = filter_mod.FilterExpression(
        field_names=set(root), include=include, exclude=exclude
    )
    if filter_expr.parse_result is None:
        # Nothing was parsed from include/exclude: only regions/targets apply.
        filter_expr = None
    else:
        # Only the fields referenced by the expression need to be read.
        filter_fields = list(filter_expr.referenced_fields)
        filter_fields_reader = VariantChunkReader(root, fields=filter_fields)

    for v_chunk, v_mask_chunk in variant_chunk_index_iter(
        root, variant_regions, variant_targets
    ):
        if filter_expr is not None:
            chunk_data = filter_fields_reader[v_chunk]
            v_mask_chunk_filter = filter_expr.evaluate(chunk_data)
            if v_mask_chunk is None:
                v_mask_chunk = v_mask_chunk_filter
            else:
                if v_mask_chunk_filter.ndim == 2:
                    # Add a trailing axis so the 1-D region mask broadcasts
                    # against the 2-D filter mask.
                    v_mask_chunk = np.expand_dims(v_mask_chunk, axis=1)
                v_mask_chunk = np.logical_and(v_mask_chunk, v_mask_chunk_filter)
        if v_mask_chunk is None or np.any(v_mask_chunk):
            yield v_chunk, v_mask_chunk
| 177 | + |
31 | 178 |
|
32 |
| -def variant_chunk_iter(root, fields=None, variant_select=None): |
33 |
| - chunk_reader = VariantChunkReader(root, fields=fields) |
34 |
| - for chunk in range(len(chunk_reader)): |
35 |
| - yield chunk_reader[chunk] |
def variant_chunk_iter(
    root,
    *,
    fields: Optional[list[str]] = None,
    variant_regions=None,
    variant_targets=None,
    include: Optional[str] = None,
    exclude: Optional[str] = None,
    samples_selection=None,
):
    """Iterate over the data of variant chunks matching regions, targets and filters.

    Yields one dict per selected chunk, mapping field names to the chunk's
    data restricted to the selected variants (and, for ``call_`` fields, the
    selected samples). When the filter produced a two-dimensional mask, the
    dict also contains a ``"call_mask"`` entry holding that mask for the
    selected variants, restricted to ``samples_selection`` when one is given.
    """
    query_fields_reader = VariantChunkReader(root, fields=fields)
    chunk_index_iter = variant_chunk_index_iter_with_filtering(
        root,
        variant_regions=variant_regions,
        variant_targets=variant_targets,
        include=include,
        exclude=exclude,
    )
    for v_chunk, v_mask_chunk in chunk_index_iter:
        # The variants_selection is used to subset variant chunks along
        # the variants dimension.
        # The call_mask is returned to the client to indicate which samples
        # matched (for each variant) in the case of per-sample filtering.
        if v_mask_chunk is None or v_mask_chunk.ndim == 1:
            variants_selection = v_mask_chunk
            call_mask = None
        else:
            # A variant is kept when at least one entry along axis 1 matched.
            variants_selection = np.any(v_mask_chunk, axis=1)
            call_mask = v_mask_chunk[variants_selection]
            if samples_selection is not None:
                call_mask = call_mask[:, samples_selection]
        chunk_data = query_fields_reader.get_chunk_data(
            v_chunk, variants_selection, samples_selection=samples_selection
        )
        if call_mask is not None:
            chunk_data["call_mask"] = call_mask
        yield chunk_data
0 commit comments